Skip to content
Snippets Groups Projects
Commit b9a3230c authored by Franck Dary's avatar Franck Dary
Browse files

Made splitwords transition set optional in definition of reading machine

parent 050839ed
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,8 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (debug)
config.printForDebug(stderr);
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto dictState = machine.getDict(config.getState()).getState();
auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
......
......@@ -26,7 +26,7 @@ class ReadingMachine
std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
std::unique_ptr<TransitionSet> splitWordTransitionSet;
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
private :
......@@ -39,6 +39,7 @@ class ReadingMachine
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
TransitionSet & getTransitionSet();
TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const;
Strategy & getStrategy();
Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts();
......
......@@ -69,9 +69,10 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
--curLine;
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine++], [this,path](auto sm)
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
{
this->splitWordTransitionSet.reset(new TransitionSet(path.parent_path() / sm.str(1)));
curLine++;
});
if (!util::doIfNameMatch(std::regex("Predictions : (.+)"), lines[curLine++], [this](auto sm)
......@@ -95,6 +96,11 @@ TransitionSet & ReadingMachine::getTransitionSet()
return classifier->getTransitionSet();
}
bool ReadingMachine::hasSplitWordTransitionSet() const
{
return splitWordTransitionSet.get() != nullptr;
}
TransitionSet & ReadingMachine::getSplitWordTransitionSet()
{
return *splitWordTransitionSet;
......
......@@ -43,7 +43,8 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
if (debug)
config.printForDebug(stderr);
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
if (!transition)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment