Skip to content
Snippets Groups Projects
Commit 742b3673 authored by Franck Dary's avatar Franck Dary
Browse files

Created strategy file

parent 8e616e06
No related branches found
No related tags found
No related merge requests found
......@@ -2,4 +2,5 @@ FILE(GLOB SOURCES src/*.cpp)
add_library(reading_machine STATIC ${SOURCES})
target_link_libraries(reading_machine common)
target_link_libraries(reading_machine torch_modules)
......@@ -2,8 +2,8 @@
#define CLASSIFIER__H
#include <string>
#include <torch/torch.h>
#include "TransitionSet.hpp"
#include "MLP.hpp"
class Classifier
{
......@@ -11,6 +11,7 @@ class Classifier
std::string name;
std::unique_ptr<TransitionSet> transitionSet;
MLP nn{nullptr};
public :
......
......@@ -91,6 +91,8 @@ class Config
std::size_t getStack(int relativeIndex) const;
bool hasHistory(int relativeIndex) const;
bool hasStack(int relativeIndex) const;
String getState() const;
void setState(const std::string state);
};
......
......@@ -9,6 +9,7 @@ class ReadingMachine
private :
std::string name;
std::function<std::pair<std::string, int>(const Config & config)> strategy;
std::unique_ptr<Classifier> classifier;
public :
......
#ifndef STRATEGY__H
#define STRATEGY__H
#include "Config.hpp"
#endif
......@@ -4,5 +4,6 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c
{
this->name = name;
this->transitionSet.reset(new TransitionSet(tsFile));
this->nn = MLP(topology);
}
......@@ -330,3 +330,13 @@ bool Config::hasStack(int relativeIndex) const
return relativeIndex > 0 && relativeIndex < (int)stack.size();
}
Config::String Config::getState() const
{
return state;
}
void Config::setState(const std::string state)
{
this->state = state;
}
#include "ReadingMachine.hpp"
#include "util.hpp"
#include "Strategy.hpp"
ReadingMachine::ReadingMachine(const std::string & filename)
{
std::regex nameRegex("Name : (.+)[\n]");
std::regex strategyRegex("Strategy : (.+)[\n]");
std::regex classifierRegex("Classifier : (.+) (.+) (.+)[\n]");
std::FILE * file = std::fopen(filename.c_str(), "r");
......@@ -18,10 +20,9 @@ ReadingMachine::ReadingMachine(const std::string & filename)
{
if (util::doIfNameMatch(nameRegex, buffer, [this](auto sm){name = sm[1];}))
continue;
if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm)
{
classifier.reset(new Classifier(sm[1], sm[2], sm[3]));
}))
if (util::doIfNameMatch(strategyRegex, buffer, [this](auto sm){}))
continue;
if (util::doIfNameMatch(classifierRegex, buffer, [this](auto sm){classifier.reset(new Classifier(sm[1], sm[2], sm[3]));}))
continue;
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));}
}
......
#include "Strategy.hpp"
#include "util.hpp"
#ifndef MLP__H
#define MLP__H
class MLPImpl
#include <torch/torch.h>
class MLPImpl : torch::nn::Module
{
public :
MLPImpl(const std::string & topology);
};
TORCH_MODULE(MLP);
#endif
#include "MLP.hpp"
#include <regex>
MLPImpl::MLPImpl(const std::string & topology)
{
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment