Skip to content
Snippets Groups Projects
Commit b9a3230c authored by Franck Dary's avatar Franck Dary
Browse files

Made splitwords transition set optional in definition of reading machine

parent 050839ed
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (debug) if (debug)
config.printForDebug(stderr); config.printForDebug(stderr);
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto dictState = machine.getDict(config.getState()).getState(); auto dictState = machine.getDict(config.getState()).getState();
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back(); auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
......
...@@ -26,7 +26,7 @@ class ReadingMachine ...@@ -26,7 +26,7 @@ class ReadingMachine
std::map<std::string, Dict> dicts; std::map<std::string, Dict> dicts;
std::set<std::string> predicted; std::set<std::string> predicted;
std::unique_ptr<TransitionSet> splitWordTransitionSet; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
private : private :
...@@ -39,6 +39,7 @@ class ReadingMachine ...@@ -39,6 +39,7 @@ class ReadingMachine
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
TransitionSet & getSplitWordTransitionSet(); TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const;
Strategy & getStrategy(); Strategy & getStrategy();
Dict & getDict(const std::string & state); Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts(); std::map<std::string, Dict> & getDicts();
......
...@@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
--curLine; --curLine;
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm) util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
{ {
this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1))); this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
curLine++;
}); });
if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm) if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
...@@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet() ...@@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet()
return classifier->getTransitionSet(); return classifier->getTransitionSet();
} }
bool ReadingMachine::hasSplitWordTransitionSet() const
{
return splitWordTransitionSet.get() != nullptr;
}
TransitionSet & ReadingMachine::getSplitWordTransitionSet() TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{ {
return *splitWordTransitionSet; return *splitWordTransitionSet;
......
...@@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: ...@@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
if (debug) if (debug)
config.printForDebug(stderr); config.printForDebug(stderr);
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition) if (!transition)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment