From 742b3673715259571c9ce4eb88ec1c3417f1ec91 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 6 Jan 2020 18:25:10 +0100 Subject: [PATCH] Created strategy file --- reading_machine/CMakeLists.txt | 1 + reading_machine/include/Classifier.hpp | 3 ++- reading_machine/include/Config.hpp | 2 ++ reading_machine/include/ReadingMachine.hpp | 1 + reading_machine/include/Strategy.hpp | 6 ++++++ reading_machine/src/Classifier.cpp | 1 + reading_machine/src/Config.cpp | 10 ++++++++++ reading_machine/src/ReadingMachine.cpp | 9 +++++---- reading_machine/src/Strategy.cpp | 3 +++ torch_modules/include/MLP.hpp | 7 ++++++- torch_modules/src/MLP.cpp | 7 +++++++ 11 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 reading_machine/include/Strategy.hpp create mode 100644 reading_machine/src/Strategy.cpp diff --git a/reading_machine/CMakeLists.txt b/reading_machine/CMakeLists.txt index 5e8450d..a315e4e 100644 --- a/reading_machine/CMakeLists.txt +++ b/reading_machine/CMakeLists.txt @@ -2,4 +2,5 @@ FILE(GLOB SOURCES src/*.cpp) add_library(reading_machine STATIC ${SOURCES}) target_link_libraries(reading_machine common) +target_link_libraries(reading_machine torch_modules) diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 41cac65..ce61d5a 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -2,8 +2,8 @@ #define CLASSIFIER__H #include <string> -#include <torch/torch.h> #include "TransitionSet.hpp" +#include "MLP.hpp" class Classifier { @@ -11,6 +11,7 @@ class Classifier std::string name; std::unique_ptr<TransitionSet> transitionSet; + MLP nn{nullptr}; public : diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index efd42cf..464286f 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -91,6 +91,8 @@ class Config std::size_t getStack(int relativeIndex) const; bool hasHistory(int relativeIndex) const; bool hasStack(int relativeIndex) const; + String getState() const; + void setState(const std::string state); }; diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 7a37a0c..339e25a 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -9,6 +9,7 @@ class ReadingMachine private : std::string name; + std::function<std::pair<std::string, int>(const Config & config)> strategy; std::unique_ptr<Classifier> classifier; public : diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp new file mode 100644 index 0000000..bd42d4f --- /dev/null +++ b/reading_machine/include/Strategy.hpp @@ -0,0 +1,6 @@ +#ifndef STRATEGY__H +#define STRATEGY__H + +#include "Config.hpp" + +#endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 29791a5..47100c7 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -4,5 +4,6 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c { this->name = name; this->transitionSet.reset(new TransitionSet(tsFile)); + this->nn = MLP(topology); } diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 85092e6..919ab25 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -330,3 +330,13 @@ bool Config::hasStack(int relativeIndex) const return relativeIndex > 0 && relativeIndex < (int)stack.size(); } +Config::String Config::getState() const +{ + return state; +} + +void Config::setState(const std::string state) +{ + this->state = state; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 9160a89..9528bb8 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -1,9 +1,11 @@ #include "ReadingMachine.hpp" #include "util.hpp" +#include "Strategy.hpp" ReadingMachine::ReadingMachine(const std::string & filename) { std::regex nameRegex("Name : (.+)[\n]"); + std::regex strategyRegex("Strategy : (.+)[\n]"); std::regex classifierRegex("Classifier : (.+) (.+) (.+)[\n]"); std::FILE * file = std::fopen(filename.c_str(), "r"); @@ -18,10 +20,9 @@ ReadingMachine::ReadingMachine(const std::string & filename) { if (util::doIfNameMatch(nameRegex, buffer, [this](auto sm){name = sm[1];})) continue; - if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm) - { - classifier.reset(new Classifier(sm[1], sm[2], sm[3])); - })) + if (util::doIfNameMatch(strategyRegex, buffer, [this](auto sm){})) + continue; + if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm){classifier.reset(new Classifier(sm[1], sm[2], sm[3]));})) continue; } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));} } diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp new file mode 100644 index 0000000..4beb7af --- /dev/null +++ b/reading_machine/src/Strategy.cpp @@ -0,0 +1,3 @@ +#include "Strategy.hpp" +#include "util.hpp" + diff --git a/torch_modules/include/MLP.hpp b/torch_modules/include/MLP.hpp index ddbd44f..90bde50 100644 --- a/torch_modules/include/MLP.hpp +++ b/torch_modules/include/MLP.hpp @@ -1,9 +1,14 @@ #ifndef MLP__H #define MLP__H -class MLPImpl +#include <torch/torch.h> + +class MLPImpl : torch::nn::Module { + public : + MLPImpl(const std::string & topology); }; +TORCH_MODULE(MLP); #endif diff --git a/torch_modules/src/MLP.cpp b/torch_modules/src/MLP.cpp index 0a5a320..182e880 100644 --- a/torch_modules/src/MLP.cpp +++ b/torch_modules/src/MLP.cpp @@ -1 +1,8 @@ #include "MLP.hpp" +#include <regex> + +MLPImpl::MLPImpl(const std::string & topology) +{ + +} + -- GitLab