Skip to content
Snippets Groups Projects
Commit b9bb67a5 authored by Franck Dary
Browse files

Strategy is now part of Config

parent ef5992ee
Branches
No related tags found
No related merge requests found
...@@ -7,12 +7,13 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -7,12 +7,13 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement)
{ {
constexpr float beamThreshold = 0.1;
constexpr int printInterval = 50;
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.trainMode(false); machine.trainMode(false);
machine.setDictsState(Dict::State::Closed); machine.setDictsState(Dict::State::Closed);
machine.getStrategy().reset();
constexpr int printInterval = 50;
int nbExamplesProcessed = 0; int nbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now(); auto pastTime = std::chrono::high_resolution_clock::now();
...@@ -22,16 +23,15 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -22,16 +23,15 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
try try
{ {
baseConfig.addPredicted(machine.getPredicted());
baseConfig.setState(machine.getStrategy().getInitialState());
for (unsigned int i = 0; i < beamSize; i++) for (unsigned int i = 0; i < beamSize; i++)
{ {
beam.emplace_back(baseConfig); beam.emplace_back(baseConfig);
beam.back().setStrategy(machine.getStrategyDefinition());
beam.back().addPredicted(machine.getPredicted());
beam.back().setState(beam.back().getStrategy().getInitialState());
endFlag.emplace_back(false); endFlag.emplace_back(false);
} }
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
while (true) while (true)
{ {
if (machine.hasSplitWordTransitionSet()) if (machine.hasSplitWordTransitionSet())
...@@ -41,6 +41,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -41,6 +41,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
std::vector<std::vector<int>> appliableTransitions; std::vector<std::vector<int>> appliableTransitions;
for (auto & c : beam) for (auto & c : beam)
{ {
machine.getClassifier()->setState(c.getState());
appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c)); appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c));
c.setAppliableTransitions(appliableTransitions.back()); c.setAppliableTransitions(appliableTransitions.back());
} }
...@@ -61,6 +62,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -61,6 +62,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++) for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++)
{ {
auto & c = beam[beamIndex]; auto & c = beam[beamIndex];
machine.getClassifier()->setState(c.getState());
c.printForDebug(stderr); c.printForDebug(stderr);
auto softmaxed = torch::softmax(predictions[beamIndex],-1); auto softmaxed = torch::softmax(predictions[beamIndex],-1);
std::vector<std::pair<float,std::string>> toPrint; std::vector<std::pair<float,std::string>> toPrint;
...@@ -83,8 +85,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -83,8 +85,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
if (endFlag[beamIndex]) if (endFlag[beamIndex])
continue; continue;
auto & c = beam[beamIndex]; auto & c = beam[beamIndex];
machine.getClassifier()->setState(c.getState());
int chosenTransition = -1; int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min(); float bestScore = std::numeric_limits<float>::min();
auto softmaxed = torch::softmax(predictions[beamIndex], 0);
std::vector<int> consideredTransitions;
try try
{ {
...@@ -96,7 +101,10 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -96,7 +101,10 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
chosenTransition = i; chosenTransition = i;
bestScore = score; bestScore = score;
} }
if (softmaxed[i].item<float>() >= beamThreshold)
consideredTransitions.emplace_back(i);
} }
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
if (chosenTransition == -1) if (chosenTransition == -1)
...@@ -110,7 +118,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -110,7 +118,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
transition->apply(c); transition->apply(c);
c.addToHistory(transition->getName()); c.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(c, transition->getName()); auto movement = c.getStrategy().getMovement(c, transition->getName());
if (debug) if (debug)
{ {
//TODO improve this for beam search //TODO improve this for beam search
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <boost/circular_buffer.hpp> #include <boost/circular_buffer.hpp>
#include "util.hpp" #include "util.hpp"
#include "Dict.hpp" #include "Dict.hpp"
#include "Strategy.hpp"
class Transition; class Transition;
...@@ -49,6 +50,7 @@ class Config ...@@ -49,6 +50,7 @@ class Config
int currentWordId{0}; int currentWordId{0};
std::vector<Transition *> appliableSplitTransitions; std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions; std::vector<int> appliableTransitions;
std::shared_ptr<Strategy> strategy;
protected : protected :
...@@ -156,6 +158,8 @@ class Config ...@@ -156,6 +158,8 @@ class Config
const std::vector<Transition *> & getAppliableSplitTransitions() const; const std::vector<Transition *> & getAppliableSplitTransitions() const;
const std::vector<int> & getAppliableTransitions() const; const std::vector<int> & getAppliableTransitions() const;
bool isExtraColumn(const std::string & colName) const; bool isExtraColumn(const std::string & colName) const;
void setStrategy(std::vector<std::string> & strategyDefinition);
Strategy & getStrategy();
}; };
#endif #endif
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#include <filesystem> #include <filesystem>
#include <memory> #include <memory>
#include "Classifier.hpp" #include "Classifier.hpp"
#include "Strategy.hpp"
class ReadingMachine class ReadingMachine
{ {
...@@ -19,7 +18,7 @@ class ReadingMachine ...@@ -19,7 +18,7 @@ class ReadingMachine
std::string name; std::string name;
std::filesystem::path path; std::filesystem::path path;
std::unique_ptr<Classifier> classifier; std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy; std::vector<std::string> strategyDefinition;
std::set<std::string> predicted; std::set<std::string> predicted;
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
...@@ -36,7 +35,7 @@ class ReadingMachine ...@@ -36,7 +35,7 @@ class ReadingMachine
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
TransitionSet & getSplitWordTransitionSet(); TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const; bool hasSplitWordTransitionSet() const;
Strategy & getStrategy(); std::vector<std::string> & getStrategyDefinition();
Classifier * getClassifier(); Classifier * getClassifier();
bool isPredicted(const std::string & columnName) const; bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const; const std::set<std::string> & getPredicted() const;
......
#ifndef STRATEGY__H #ifndef STRATEGY__H
#define STRATEGY__H #define STRATEGY__H
#include "Config.hpp" #include <string>
#include <vector>
class Config;
class Strategy class Strategy
{ {
......
...@@ -727,3 +727,16 @@ std::size_t Config::getStackSize() const ...@@ -727,3 +727,16 @@ std::size_t Config::getStackSize() const
return stack.size(); return stack.size();
} }
void Config::setStrategy(std::vector<std::string> & strategyDefinition)
{
strategy.reset(new Strategy(strategyDefinition));
}
/// @brief Gives access to the Strategy attached to this Config.
///
/// If no strategy has been installed yet, the error is reported through
/// util::myThrow with the message "strategy was not set".
/// NOTE(review): util::myThrow is presumably non-returning (throws) —
/// confirm, since the dereference below relies on it.
///
/// @return Reference to the Strategy held by the `strategy` member.
Strategy & Config::getStrategy()
{
  const bool hasStrategy = static_cast<bool>(strategy);
  if (!hasStrategy)
  {
    util::myThrow("strategy was not set");
  }

  return *strategy;
}
...@@ -90,7 +90,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -90,7 +90,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
{ {
std::vector<std::string> strategyDefinition; strategyDefinition.clear();
if (lines[curLine] != "{") if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine])); util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
...@@ -100,7 +100,6 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -100,7 +100,6 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
break; break;
strategyDefinition.emplace_back(lines[curLine]); strategyDefinition.emplace_back(lines[curLine]);
} }
strategy.reset(new Strategy(strategyDefinition));
})) }))
util::myThrow("No Strategy specified"); util::myThrow("No Strategy specified");
...@@ -122,9 +121,9 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet() ...@@ -122,9 +121,9 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet()
return *splitWordTransitionSet; return *splitWordTransitionSet;
} }
Strategy & ReadingMachine::getStrategy() std::vector<std::string> & ReadingMachine::getStrategyDefinition()
{ {
return *strategy; return strategyDefinition;
} }
Classifier * ReadingMachine::getClassifier() Classifier * ReadingMachine::getClassifier()
......
#include "Strategy.hpp" #include "Strategy.hpp"
#include "Config.hpp"
Strategy::Strategy(std::vector<std::string> definition) Strategy::Strategy(std::vector<std::string> definition)
{ {
......
...@@ -41,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -41,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
std::filesystem::create_directories(dir); std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
machine.getStrategy().reset(); config.setStrategy(machine.getStrategyDefinition());
config.setState(machine.getStrategy().getInitialState()); config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState()); machine.getClassifier()->setState(config.getState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch); auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile); bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
...@@ -132,7 +132,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -132,7 +132,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
transition->apply(config); transition->apply(config);
config.addToHistory(transition->getName()); config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName()); auto movement = config.getStrategy().getMovement(config, transition->getName());
if (debug) if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second); fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement) if (movement == Strategy::endMovement)
...@@ -291,9 +291,9 @@ void Trainer::fillDicts(SubConfig & config, bool debug) ...@@ -291,9 +291,9 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
machine.getStrategy().reset(); config.setStrategy(machine.getStrategyDefinition());
config.setState(machine.getStrategy().getInitialState()); config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState()); machine.getClassifier()->setState(config.getState());
while (true) while (true)
{ {
...@@ -325,7 +325,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug) ...@@ -325,7 +325,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
goldTransition->apply(config); goldTransition->apply(config);
config.addToHistory(goldTransition->getName()); config.addToHistory(goldTransition->getName());
auto movement = machine.getStrategy().getMovement(config, goldTransition->getName()); auto movement = config.getStrategy().getMovement(config, goldTransition->getName());
if (debug) if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second); fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement) if (movement == Strategy::endMovement)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment