Skip to content
Snippets Groups Projects
Commit f0df0508 authored by Franck Dary's avatar Franck Dary
Browse files

Speed up oracle

parent 79837075
No related branches found
No related tags found
No related merge requests found
...@@ -12,8 +12,8 @@ class Transition ...@@ -12,8 +12,8 @@ class Transition
std::string name; std::string name;
std::string state; std::string state;
std::vector<Action> sequence; std::vector<Action> sequence;
std::function<int(const Config & config)> costDynamic; std::function<int(const Config & config, const std::map<std::string, int> & links)> costDynamic;
std::function<int(const Config & config)> costStatic; std::function<int(const Config & config, const std::map<std::string, int> & links)> costStatic;
std::function<bool(const Config & config)> precondition{[](const Config&){return true;}}; std::function<bool(const Config & config)> precondition{[](const Config&){return true;}};
private : private :
...@@ -64,8 +64,8 @@ class Transition ...@@ -64,8 +64,8 @@ class Transition
void apply(Config & config, float entropy); void apply(Config & config, float entropy);
void apply(Config & config); void apply(Config & config);
bool appliable(const Config & config) const; bool appliable(const Config & config) const;
int getCostDynamic(const Config & config) const; int getCostDynamic(const Config & config, const std::map<std::string, int> & links) const;
int getCostStatic(const Config & config) const; int getCostStatic(const Config & config, const std::map<std::string, int> & links) const;
const std::string & getName() const; const std::string & getName() const;
}; };
......
...@@ -28,6 +28,7 @@ class TransitionSet ...@@ -28,6 +28,7 @@ class TransitionSet
Transition * getTransition(std::size_t index); Transition * getTransition(std::size_t index);
Transition * getTransition(const std::string & name); Transition * getTransition(const std::string & name);
std::size_t size() const; std::size_t size() const;
std::map<std::string, int> computeLinks(const Config & c);
}; };
#endif #endif
...@@ -112,6 +112,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent ...@@ -112,6 +112,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
try try
{ {
std::map<std::string, int> id2index; std::map<std::string, int> id2index;
std::map<int, std::vector<std::string>> childs;
int firstIndexOfSequence = getNbLines()-1; int firstIndexOfSequence = getNbLines()-1;
for (int i = (int)getNbLines()-1; has(0, i, 0); --i) for (int i = (int)getNbLines()-1; has(0, i, 0); --i)
{ {
...@@ -125,6 +126,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent ...@@ -125,6 +126,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
id2index[getConst(idColName, i, 0)] = i; id2index[getConst(idColName, i, 0)] = i;
} }
if (hasColIndex(headColName)) if (hasColIndex(headColName))
{
for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i) for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i)
{ {
if (!isToken(i)) if (!isToken(i))
...@@ -133,8 +135,14 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent ...@@ -133,8 +135,14 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
if (head == "0") if (head == "0")
head = "-1"; head = "-1";
else else
{
childs[id2index[head]].emplace_back(fmt::format("{}",i));
head = std::to_string(id2index[head]); head = std::to_string(id2index[head]);
}
} }
for (auto it : childs)
get(Config::childsColName, it.first, 0) = util::join("|", it.second);
}
get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; get(EOSColName, getNbLines()-1, 0) = EOSSymbol1;
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
......
...@@ -175,17 +175,17 @@ bool Transition::appliable(const Config & config) const ...@@ -175,17 +175,17 @@ bool Transition::appliable(const Config & config) const
return true; return true;
} }
int Transition::getCostDynamic(const Config & config) const int Transition::getCostDynamic(const Config & config, const std::map<std::string, int> & links) const
{ {
try {return costDynamic(config);} try {return costDynamic(config, links);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0; return 0;
} }
int Transition::getCostStatic(const Config & config) const int Transition::getCostStatic(const Config & config, const std::map<std::string, int> & links) const
{ {
try {return costStatic(config);} try {return costStatic(config, links);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0; return 0;
...@@ -203,7 +203,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string ...@@ -203,7 +203,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value)); sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
costDynamic = [colName, objectValue, indexValue, value](const Config & config) costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
...@@ -223,7 +223,7 @@ void Transition::initWriteScore(std::string colName, std::string object, std::st ...@@ -223,7 +223,7 @@ void Transition::initWriteScore(std::string colName, std::string object, std::st
sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue)); sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue));
costDynamic = [](const Config &) costDynamic = [](const Config &, const std::map<std::string, int> &)
{ {
return 0; return 0;
}; };
...@@ -238,7 +238,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in ...@@ -238,7 +238,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value)); sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
costDynamic = [colName, objectValue, indexValue, value](const Config & config) costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
...@@ -256,7 +256,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in ...@@ -256,7 +256,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
void Transition::initNothing() void Transition::initNothing()
{ {
costDynamic = [](const Config &) costDynamic = [](const Config &, const std::map<std::string, int> &)
{ {
return 0; return 0;
}; };
...@@ -268,7 +268,7 @@ void Transition::initIgnoreChar() ...@@ -268,7 +268,7 @@ void Transition::initIgnoreChar()
{ {
sequence.emplace_back(Action::ignoreCurrentCharacter()); sequence.emplace_back(Action::ignoreCurrentCharacter());
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex())); auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0))); auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0)));
...@@ -286,7 +286,7 @@ void Transition::initEndWord() ...@@ -286,7 +286,7 @@ void Transition::initEndWord()
{ {
sequence.emplace_back(Action::endWord()); sequence.emplace_back(Action::endWord());
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex())) if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
return 0; return 0;
...@@ -304,7 +304,7 @@ void Transition::initAddCharToWord(int n) ...@@ -304,7 +304,7 @@ void Transition::initAddCharToWord(int n)
sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0)); sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0));
sequence.emplace_back(Action::moveCharacterIndex(n)); sequence.emplace_back(Action::moveCharacterIndex(n));
costDynamic = [n](const Config & config) costDynamic = [n](const Config & config, const std::map<std::string, int> &)
{ {
if (!config.hasCharacter(config.getCharacterIndex()+n-1)) if (!config.hasCharacter(config.getCharacterIndex()+n-1))
return std::numeric_limits<int>::max(); return std::numeric_limits<int>::max();
...@@ -345,7 +345,7 @@ void Transition::initSplitWord(std::vector<std::string> words) ...@@ -345,7 +345,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
} }
sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
costDynamic = [words](const Config & config) costDynamic = [words](const Config & config, const std::map<std::string, int> &)
{ {
if (!config.isMultiword(config.getWordIndex())) if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max(); return std::numeric_limits<int>::max();
...@@ -367,14 +367,14 @@ void Transition::initSplit(int index) ...@@ -367,14 +367,14 @@ void Transition::initSplit(int index)
{ {
sequence.emplace_back(Action::split(index)); sequence.emplace_back(Action::split(index));
costDynamic = [index](const Config & config) costDynamic = [index](const Config & config, const std::map<std::string, int> & links)
{ {
auto & transitions = config.getAppliableSplitTransitions(); auto & transitions = config.getAppliableSplitTransitions();
if (index < 0 or index >= (int)transitions.size()) if (index < 0 or index >= (int)transitions.size())
return std::numeric_limits<int>::max(); return std::numeric_limits<int>::max();
return transitions[index]->getCostDynamic(config); return transitions[index]->getCostDynamic(config, links);
}; };
costStatic = costDynamic; costStatic = costDynamic;
...@@ -384,25 +384,22 @@ void Transition::initEagerShift() ...@@ -384,25 +384,22 @@ void Transition::initEagerShift()
{ {
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> & links)
{ {
if (!config.isToken(config.getWordIndex())) if (!config.isToken(config.getWordIndex()))
return 0; return 0;
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); return links.at("BufferStack");
}; };
costStatic = [](const Config &) costStatic = costDynamic;
{
return 0;
};
} }
void Transition::initGoldEagerShift() void Transition::initGoldEagerShift()
{ {
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
if (!config.isToken(config.getWordIndex())) if (!config.isToken(config.getWordIndex()))
return 0; return 0;
...@@ -410,7 +407,7 @@ void Transition::initGoldEagerShift() ...@@ -410,7 +407,7 @@ void Transition::initGoldEagerShift()
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
}; };
costStatic = [](const Config &) costStatic = [](const Config &, const std::map<std::string, int> &)
{ {
return 0; return 0;
}; };
...@@ -428,7 +425,7 @@ void Transition::initStandardShift() ...@@ -428,7 +425,7 @@ void Transition::initStandardShift()
{ {
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config &) costDynamic = [](const Config &, const std::map<std::string, int> &)
{ {
return 0; return 0;
}; };
...@@ -442,23 +439,23 @@ void Transition::initEagerLeft_rel(std::string label) ...@@ -442,23 +439,23 @@ void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex(); auto govIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); int cost = 0;
if (label != config.getConst(Config::deprelColName, depIndex, 0)) if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost; ++cost;
if (depGovIndex != std::to_string(govIndex)) if (depGovIndex != std::to_string(govIndex))
++cost; cost += links.at("StackRight");
return cost; return cost;
}; };
costStatic = [label](const Config & config) costStatic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
...@@ -477,7 +474,7 @@ void Transition::initGoldEagerLeft_rel(std::string label) ...@@ -477,7 +474,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex(); auto govIndex = config.getWordIndex();
...@@ -490,7 +487,7 @@ void Transition::initGoldEagerLeft_rel(std::string label) ...@@ -490,7 +487,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
return cost; return cost;
}; };
costStatic = [label](const Config & config) costStatic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
...@@ -522,7 +519,7 @@ void Transition::initStandardLeft_rel(std::string label) ...@@ -522,7 +519,7 @@ void Transition::initStandardLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
sequence.emplace_back(Action::popStack(1)); sequence.emplace_back(Action::popStack(1));
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(1); auto depIndex = config.getStack(1);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
...@@ -548,7 +545,7 @@ void Transition::initEagerLeft() ...@@ -548,7 +545,7 @@ void Transition::initEagerLeft()
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex(); auto govIndex = config.getWordIndex();
...@@ -562,7 +559,7 @@ void Transition::initEagerLeft() ...@@ -562,7 +559,7 @@ void Transition::initEagerLeft()
return cost; return cost;
}; };
costStatic = [](const Config & config) costStatic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
...@@ -581,24 +578,23 @@ void Transition::initEagerRight_rel(std::string label) ...@@ -581,24 +578,23 @@ void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
{ {
auto govIndex = config.getStack(0); auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); int cost = 0;
cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0)) if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost; ++cost;
if (depGovIndex == std::to_string(govIndex)) if (depGovIndex != std::to_string(govIndex))
++cost; cost += links.at("BufferStack") + links.at("BufferRightHead");
return cost; return cost;
}; };
costStatic = [label](const Config & config) costStatic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto govIndex = config.getStack(0); auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
...@@ -617,7 +613,7 @@ void Transition::initGoldEagerRight_rel(std::string label) ...@@ -617,7 +613,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
...@@ -630,7 +626,7 @@ void Transition::initGoldEagerRight_rel(std::string label) ...@@ -630,7 +626,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
return cost; return cost;
}; };
costStatic = [label](const Config & config) costStatic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto govIndex = config.getStack(0); auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
...@@ -662,7 +658,7 @@ void Transition::initStandardRight_rel(std::string label) ...@@ -662,7 +658,7 @@ void Transition::initStandardRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{ {
auto govIndex = config.getStack(1); auto govIndex = config.getStack(1);
auto depIndex = config.getStack(0); auto depIndex = config.getStack(0);
...@@ -688,7 +684,7 @@ void Transition::initEagerRight() ...@@ -688,7 +684,7 @@ void Transition::initEagerRight()
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
auto govIndex = config.getStack(0); auto govIndex = config.getStack(0);
...@@ -703,7 +699,7 @@ void Transition::initEagerRight() ...@@ -703,7 +699,7 @@ void Transition::initEagerRight()
return cost; return cost;
}; };
costStatic = [](const Config & config) costStatic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto govIndex = config.getStack(0); auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex(); auto depIndex = config.getWordIndex();
...@@ -722,17 +718,12 @@ void Transition::initReduce_strict() ...@@ -722,17 +718,12 @@ void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> & links)
{ {
auto stackIndex = config.getStack(0); if (!config.isToken(config.getStack(0)))
auto wordIndex = config.getWordIndex();
if (!config.isToken(stackIndex))
return 0; return 0;
int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); return links.at("StackRight");
return cost;
}; };
costStatic = costDynamic; costStatic = costDynamic;
...@@ -743,7 +734,7 @@ void Transition::initGoldReduce_strict() ...@@ -743,7 +734,7 @@ void Transition::initGoldReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto stackIndex = config.getStack(0); auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex(); auto wordIndex = config.getWordIndex();
...@@ -776,7 +767,7 @@ void Transition::initReduce_relaxed() ...@@ -776,7 +767,7 @@ void Transition::initReduce_relaxed()
{ {
sequence.emplace_back(Action::popStack(0)); sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config) costDynamic = [](const Config & config, const std::map<std::string, int> &)
{ {
auto stackIndex = config.getStack(0); auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex(); auto wordIndex = config.getWordIndex();
...@@ -799,7 +790,7 @@ void Transition::initEOS(int bufferIndex) ...@@ -799,7 +790,7 @@ void Transition::initEOS(int bufferIndex)
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack()); sequence.emplace_back(Action::emptyStack());
costDynamic = [bufferIndex](const Config & config) costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1) if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
...@@ -813,7 +804,7 @@ void Transition::initEOS(int bufferIndex) ...@@ -813,7 +804,7 @@ void Transition::initEOS(int bufferIndex)
void Transition::initNotEOS(int bufferIndex) void Transition::initNotEOS(int bufferIndex)
{ {
costDynamic = [bufferIndex](const Config & config) costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1) if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
...@@ -829,7 +820,7 @@ void Transition::initDeprel(std::string label) ...@@ -829,7 +820,7 @@ void Transition::initDeprel(std::string label)
{ {
sequence.emplace_back(Action::deprel(label)); sequence.emplace_back(Action::deprel(label));
costDynamic = [label](const Config & config) costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{ {
return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1; return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
}; };
...@@ -857,7 +848,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s ...@@ -857,7 +848,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s
toAddUtf8 = util::splitAsUtf8(toAdd); toAddUtf8 = util::splitAsUtf8(toAdd);
sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8)); sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8));
costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config) costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config, const std::map<std::string, int> &)
{ {
int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue); int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue);
int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue); int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue);
...@@ -883,7 +874,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind ...@@ -883,7 +874,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::uppercase(col, objectValue, indexValue)); sequence.emplace_back(Action::uppercase(col, objectValue, indexValue));
costDynamic = [col, objectValue, indexValue](const Config & config) costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0); auto & expectedValue = config.getConst(col, lineIndex, 0);
...@@ -908,7 +899,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin ...@@ -908,7 +899,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue)); sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue));
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0); auto & expectedValue = config.getConst(col, lineIndex, 0);
...@@ -932,7 +923,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index ...@@ -932,7 +923,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index
auto objectValue = Config::str2object(obj); auto objectValue = Config::str2object(obj);
int indexValue = std::stoi(index); int indexValue = std::stoi(index);
costDynamic = [col, objectValue, indexValue](const Config & config) costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0); auto & expectedValue = config.getConst(col, lineIndex, 0);
...@@ -953,7 +944,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind ...@@ -953,7 +944,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::lowercase(col, objectValue, indexValue)); sequence.emplace_back(Action::lowercase(col, objectValue, indexValue));
costDynamic = [col, objectValue, indexValue](const Config & config) costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0); auto & expectedValue = config.getConst(col, lineIndex, 0);
...@@ -978,7 +969,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin ...@@ -978,7 +969,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue)); sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue));
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
{ {
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0); auto & expectedValue = config.getConst(col, lineIndex, 0);
......
...@@ -40,9 +40,11 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC ...@@ -40,9 +40,11 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC
using Pair = std::pair<Transition*, int>; using Pair = std::pair<Transition*, int>;
std::vector<Pair> appliableTransitions; std::vector<Pair> appliableTransitions;
auto links = computeLinks(c);
for (unsigned int i = 0; i < transitions.size(); i++) for (unsigned int i = 0; i < transitions.size(); i++)
if (transitions[i].appliable(c)) if (transitions[i].appliable(c))
appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c)); appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links));
std::sort(appliableTransitions.begin(), appliableTransitions.end(), std::sort(appliableTransitions.begin(), appliableTransitions.end(),
[](const Pair & a, const Pair & b) [](const Pair & a, const Pair & b)
...@@ -80,12 +82,64 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c) ...@@ -80,12 +82,64 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
return result; return result;
} }
/// Counts dependency links around the stack top and the current word of `c`.
///
/// Head and child values are read with `c.getConst(column, line, 0)`. A head
/// equal to "_" is treated as "no head" and never contributes to a counter.
///
/// @param c Configuration to inspect; it is only read.
/// @return A map that always holds exactly these four keys:
///   - "StackRight"      : links of the stack top (`c.getStack(0)`), i.e. its
///                         head plus each of its children, whose other end is
///                         at an index >= `c.getWordIndex()`.
///   - "BufferRight"     : links of the current word (head plus children)
///                         whose other end is at an index > `c.getWordIndex()`.
///   - "BufferRightHead" : 1 when the head of the current word is at an index
///                         > `c.getWordIndex()`, 0 otherwise.
///   - "BufferStack"     : number of stack elements that are the head of the
///                         current word, or whose head is the current word.
///   Every counter is 0 when `c.has(Config::headColName,0,0)` is false.
/// @throws Whatever `std::stoi` throws when a head other than "_" or a child
///   entry is not a parsable integer.
std::map<std::string, int> TransitionSet::computeLinks(const Config & c)
{
  int nbLinksStackRight = 0;
  int nbLinksBufferRight = 0;
  int nbLinksBufferRightHead = 0;
  int nbLinksBufferStack = 0;

  if (c.has(Config::headColName,0,0))
  {
    const auto wordIndex = c.getWordIndex();

    if (c.hasStack(0))
    {
      const auto stackTop = c.getStack(0);
      const auto & stackTopHead = c.getConst(Config::headColName, stackTop, 0);

      // Same "_" guard as for the other head values below: an unset head must
      // not reach std::stoi, which would throw std::invalid_argument.
      if (stackTopHead != "_" and static_cast<std::size_t>(std::stoi(stackTopHead)) >= wordIndex)
        nbLinksStackRight++;

      // NOTE(review): child entries are parsed unguarded, as before; this
      // relies on util::split yielding only numeric tokens - confirm.
      const auto stackTopChilds = util::split(c.getConst(Config::childsColName, stackTop, 0), '|');
      for (const auto & child : stackTopChilds)
        if (static_cast<std::size_t>(std::stoi(child)) >= wordIndex)
          nbLinksStackRight++;
    }

    // The head of the current word is looked up and parsed once, then reused
    // for both the "right" counters and the stack scan.
    const auto & bufferHead = c.getConst(Config::headColName, wordIndex, 0);
    const bool bufferHasHead = bufferHead != "_";
    const std::size_t bufferHeadIndex = bufferHasHead ? static_cast<std::size_t>(std::stoi(bufferHead)) : 0;

    if (bufferHasHead and bufferHeadIndex > wordIndex)
    {
      nbLinksBufferRight++;
      nbLinksBufferRightHead++;
    }

    const auto bufferChilds = util::split(c.getConst(Config::childsColName, wordIndex, 0), '|');
    for (const auto & child : bufferChilds)
      if (static_cast<std::size_t>(std::stoi(child)) > wordIndex)
        nbLinksBufferRight++;

    // A stack element only counts when both heads are set, so the scan is
    // pointless when the current word has no head.
    if (bufferHasHead)
    {
      for (unsigned int i = 0; i < c.getStackSize(); i++)
      {
        const auto stackIndex = c.getStack(i);
        const auto & stackHead = c.getConst(Config::headColName, stackIndex, 0);

        if (stackHead != "_" and (bufferHeadIndex == stackIndex or static_cast<std::size_t>(std::stoi(stackHead)) == wordIndex))
          nbLinksBufferStack++;
      }
    }
  }

  return {
    {"StackRight", nbLinksStackRight},
    {"BufferRight", nbLinksBufferRight},
    {"BufferRightHead", nbLinksBufferRightHead},
    {"BufferStack", nbLinksBufferStack}
  };
}
std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic) std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
{ {
int bestCost = std::numeric_limits<int>::max(); int bestCost = std::numeric_limits<int>::max();
std::vector<Transition *> result; std::vector<Transition *> result;
std::vector<int> costs(transitions.size()); std::vector<int> costs(transitions.size());
auto links = computeLinks(c);
for (unsigned int i = 0; i < transitions.size(); i++) for (unsigned int i = 0; i < transitions.size(); i++)
{ {
if (!appliableTransitions[i]) if (!appliableTransitions[i])
...@@ -94,7 +148,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi ...@@ -94,7 +148,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
continue; continue;
} }
int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c); int cost = dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links);
costs[i] = cost; costs[i] = cost;
if (cost < bestCost) if (cost < bestCost)
...@@ -104,7 +158,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi ...@@ -104,7 +158,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
for (unsigned int i = 0; i < transitions.size(); i++) for (unsigned int i = 0; i < transitions.size(); i++)
if (costs[i] == bestCost) if (costs[i] == bestCost)
result.emplace_back(&transitions[i]); result.emplace_back(&transitions[i]);
return result; return result;
} }
......
...@@ -117,6 +117,10 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: ...@@ -117,6 +117,10 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
} }
transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
for (auto & trans : goldTransitions)
if (trans == transition)
goldTransition = trans;
} }
else else
{ {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment