diff --git a/reading_machine/CMakeLists.txt b/reading_machine/CMakeLists.txt index 5e8450d418ac983587776b95949a9375c96e9d39..a315e4e92113fb899c3caae3337fab8176a92273 100644 --- a/reading_machine/CMakeLists.txt +++ b/reading_machine/CMakeLists.txt @@ -2,4 +2,5 @@ FILE(GLOB SOURCES src/*.cpp) add_library(reading_machine STATIC ${SOURCES}) target_link_libraries(reading_machine common) +target_link_libraries(reading_machine torch_modules) diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 41cac659cc109d9aa11a587d6f25d01ab8b0b3a3..ce61d5a6c850dc5efab0a9cd70acb2282168717c 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -2,8 +2,8 @@ #define CLASSIFIER__H #include <string> -#include <torch/torch.h> #include "TransitionSet.hpp" +#include "MLP.hpp" class Classifier { @@ -11,6 +11,7 @@ class Classifier std::string name; std::unique_ptr<TransitionSet> transitionSet; + MLP nn{nullptr}; public : diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index efd42cf8567a5ba2e0b99d4e0b60d9e2e755784a..464286f2ca3370381bc068fee911a889e038959a 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -91,6 +91,8 @@ class Config std::size_t getStack(int relativeIndex) const; bool hasHistory(int relativeIndex) const; bool hasStack(int relativeIndex) const; + String getState() const; + void setState(const std::string state); }; diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 7a37a0c8eae3c6dad996ec3bc221ca806c792b5b..339e25a95f888bbf6fb3c93d312a6997c485bac7 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -9,6 +9,7 @@ class ReadingMachine private : std::string name; + std::function<std::pair<std::string, int>(const Config & config)> strategy; std::unique_ptr<Classifier> classifier; public : diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd42d4f20149202a3d05550c9a6bde02f78acf5a --- /dev/null +++ b/reading_machine/include/Strategy.hpp @@ -0,0 +1,6 @@ +#ifndef STRATEGY__H +#define STRATEGY__H + +#include "Config.hpp" + +#endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 29791a5033ae5cfc15c1339d4ae3ffb732b71c51..47100c74920d8fdf399171a55b34f281ae28adc7 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -4,5 +4,6 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c { this->name = name; this->transitionSet.reset(new TransitionSet(tsFile)); + this->nn = MLP(topology); } diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 85092e69b321874bac72cf9ee9488755f73c536f..919ab2517da772a56539b3a3d8a22cae802f62ef 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -330,3 +330,13 @@ bool Config::hasStack(int relativeIndex) const return relativeIndex > 0 && relativeIndex < (int)stack.size(); } +Config::String Config::getState() const +{ + return state; +} + +void Config::setState(const std::string state) +{ + this->state = state; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 9160a890b454c52b4b7773218bcc1e0a3dca5ccf..9528bb8bd5e89a8ac449989df98bdf7d587746f4 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -1,9 +1,11 @@ #include "ReadingMachine.hpp" #include "util.hpp" +#include "Strategy.hpp" ReadingMachine::ReadingMachine(const std::string & filename) { std::regex nameRegex("Name : (.+)[\n]"); + std::regex strategyRegex("Strategy : (.+)[\n]"); std::regex classifierRegex("Classifier : (.+) (.+) (.+)[\n]"); std::FILE * file = std::fopen(filename.c_str(), "r"); @@ -18,10 +20,9 @@ ReadingMachine::ReadingMachine(const std::string & filename) { if (util::doIfNameMatch(nameRegex, buffer, [this](auto sm){name = sm[1];})) continue; - if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm) - { - classifier.reset(new Classifier(sm[1], sm[2], sm[3])); - })) + if (util::doIfNameMatch(strategyRegex, buffer, [this](auto sm){})) + continue; + if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm){classifier.reset(new Classifier(sm[1], sm[2], sm[3]));})) continue; } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));} } diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4beb7af4ea29f9e99adf41b18cf6344d65bd5fb0 --- /dev/null +++ b/reading_machine/src/Strategy.cpp @@ -0,0 +1,3 @@ +#include "Strategy.hpp" +#include "util.hpp" + diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp index ddbd44fee79bda0ceb2c518bfb500afd2a8a655b..90bde50aea779f7d6d3188d98f7058053c0ec8e2 100644 --- a/torch_modules/include/MLP.hpp +++ b/torch_modules/include/MLP.hpp @@ -1,9 +1,14 @@ #ifndef MLP__H #define MLP__H -class MLPImpl +#include <torch/torch.h> + +class MLPImpl : torch::nn::Module { + public : + MLPImpl(const std::string & topology); }; +TORCH_MODULE(MLP); #endif diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 0a5a3201bbd1a1d99fb99ef97be0bf92528e8052..182e880a6df2c60b131bc71af463e47d901b64b8 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -1 +1,8 @@ #include "MLP.hpp" +#include <regex> + +MLPImpl::MLPImpl(const std::string & topology) +{ + +} +