diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 75f9c40643ae3487ba54ddddd14349e65cb581f5..f0bb97f1aded68bdc22aaa8d4d3c84301ddd2663 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -12,8 +12,8 @@ class Transition std::string name; std::string state; std::vector<Action> sequence; - std::function<int(const Config & config)> costDynamic; - std::function<int(const Config & config)> costStatic; + std::function<int(const Config & config, const std::map<std::string, int> & links)> costDynamic; + std::function<int(const Config & config, const std::map<std::string, int> & links)> costStatic; std::function<bool(const Config & config)> precondition{[](const Config&){return true;}}; private : @@ -64,8 +64,8 @@ class Transition void apply(Config & config, float entropy); void apply(Config & config); bool appliable(const Config & config) const; - int getCostDynamic(const Config & config) const; - int getCostStatic(const Config & config) const; + int getCostDynamic(const Config & config, const std::map<std::string, int> & links) const; + int getCostStatic(const Config & config, const std::map<std::string, int> & links) const; const std::string & getName() const; }; diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 936fa88b0e715a7d2b731748101ad784f8ff10ca..21782b032501a3dc5f0cf24e54430fd181a80931 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -28,6 +28,7 @@ class TransitionSet Transition * getTransition(std::size_t index); Transition * getTransition(const std::string & name); std::size_t size() const; + std::map<std::string, int> computeLinks(const Config & c); }; #endif diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 0d373c6e3df2a052ac69a1ded930b3eb6a11ff20..61488ebdcf3016eef811e8f4a9f824755848c2d2 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -112,6 +112,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent try { std::map<std::string, int> id2index; + std::map<int, std::vector<std::string>> childs; int firstIndexOfSequence = getNbLines()-1; for (int i = (int)getNbLines()-1; has(0, i, 0); --i) { @@ -125,6 +126,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent id2index[getConst(idColName, i, 0)] = i; } if (hasColIndex(headColName)) + { for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i) { if (!isToken(i)) @@ -133,8 +135,14 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent if (head == "0") head = "-1"; else + { + childs[id2index[head]].emplace_back(fmt::format("{}",i)); head = std::to_string(id2index[head]); + } } + for (auto it : childs) + get(Config::childsColName, it.first, 0) = util::join("|", it.second); + } get(EOSColName, getNbLines()-1, 0) = EOSSymbol1; } catch(std::exception & e) {util::myThrow(e.what());} diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 858a04457dd4f132feafc3b27b633c7fde4e82ee..aeee6b4fe6a64be6f63640a72ef21dd411bf194c 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -175,17 +175,17 @@ bool Transition::appliable(const Config & config) const return true; } -int Transition::getCostDynamic(const Config & config) const +int Transition::getCostDynamic(const Config & config, const std::map<std::string, int> & links) const { - try {return costDynamic(config);} + try {return costDynamic(config, links);} catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} return 0; } -int Transition::getCostStatic(const Config & config) const +int Transition::getCostStatic(const Config & config, const std::map<std::string, int> & links) const { - try {return costStatic(config);} + try {return costStatic(config, links);} catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));} return 0; @@ -203,7 +203,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value)); - costDynamic = [colName, objectValue, indexValue, value](const Config & config) + costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); @@ -223,7 +223,7 @@ void Transition::initWriteScore(std::string colName, std::string object, std::st sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue)); - costDynamic = [](const Config &) + costDynamic = [](const Config &, const std::map<std::string, int> &) { return 0; }; @@ -238,7 +238,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value)); - costDynamic = [colName, objectValue, indexValue, value](const Config & config) + costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); @@ -256,7 +256,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in void Transition::initNothing() { - costDynamic = [](const Config &) + costDynamic = [](const Config &, const std::map<std::string, int> &) { return 0; }; @@ -268,7 +268,7 @@ void Transition::initIgnoreChar() { sequence.emplace_back(Action::ignoreCurrentCharacter()); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex())); auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0))); @@ -286,7 +286,7 @@ void Transition::initEndWord() { sequence.emplace_back(Action::endWord()); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex())) return 0; @@ -304,7 +304,7 @@ void Transition::initAddCharToWord(int n) sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0)); sequence.emplace_back(Action::moveCharacterIndex(n)); - costDynamic = [n](const Config & config) + costDynamic = [n](const Config & config, const std::map<std::string, int> &) { if (!config.hasCharacter(config.getCharacterIndex()+n-1)) return std::numeric_limits<int>::max(); @@ -345,7 +345,7 @@ void Transition::initSplitWord(std::vector<std::string> words) } sequence.emplace_back(Action::setMultiwordIds(words.size()-1)); - costDynamic = [words](const Config & config) + costDynamic = [words](const Config & config, const std::map<std::string, int> &) { if (!config.isMultiword(config.getWordIndex())) return std::numeric_limits<int>::max(); @@ -367,14 +367,14 @@ void Transition::initSplit(int index) { sequence.emplace_back(Action::split(index)); - costDynamic = [index](const Config & config) + costDynamic = [index](const Config & config, const std::map<std::string, int> & links) { auto & transitions = config.getAppliableSplitTransitions(); if (index < 0 or index >= (int)transitions.size()) return std::numeric_limits<int>::max(); - return transitions[index]->getCostDynamic(config); + return transitions[index]->getCostDynamic(config, links); }; costStatic = costDynamic; @@ -384,25 +384,22 @@ void Transition::initEagerShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> & links) { if (!config.isToken(config.getWordIndex())) return 0; - return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); + return links.at("BufferStack"); }; - costStatic = [](const Config &) - { - return 0; - }; + costStatic = costDynamic; } void Transition::initGoldEagerShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { if (!config.isToken(config.getWordIndex())) return 0; @@ -410,7 +407,7 @@ void Transition::initGoldEagerShift() return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config); }; - costStatic = [](const Config &) + costStatic = [](const Config &, const std::map<std::string, int> &) { return 0; }; @@ -428,7 +425,7 @@ void Transition::initStandardShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [](const Config &) + costDynamic = [](const Config &, const std::map<std::string, int> &) { return 0; }; @@ -442,23 +439,23 @@ void Transition::initEagerLeft_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> & links) { auto depIndex = config.getStack(0); auto govIndex = config.getWordIndex(); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config); + int cost = 0; if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; if (depGovIndex != std::to_string(govIndex)) - ++cost; + cost += links.at("StackRight"); return cost; }; - costStatic = [label](const Config & config) + costStatic = [label](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -477,7 +474,7 @@ void Transition::initGoldEagerLeft_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(0); auto govIndex = config.getWordIndex(); @@ -490,7 +487,7 @@ void Transition::initGoldEagerLeft_rel(std::string label) return cost; }; - costStatic = [label](const Config & config) + costStatic = [label](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -522,7 +519,7 @@ void Transition::initStandardLeft_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label)); sequence.emplace_back(Action::popStack(1)); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(1); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -548,7 +545,7 @@ void Transition::initEagerLeft() sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(0); auto govIndex = config.getWordIndex(); @@ -562,7 +559,7 @@ void Transition::initEagerLeft() return cost; }; - costStatic = [](const Config & config) + costStatic = [](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getStack(0); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); @@ -581,24 +578,23 @@ void Transition::initEagerRight_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> & links) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); auto depGovIndex = config.getConst(Config::headColName, depIndex, 0); - int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config); - cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config); + int cost = 0; if (label != config.getConst(Config::deprelColName, depIndex, 0)) ++cost; - if (depGovIndex == std::to_string(govIndex)) - ++cost; + if (depGovIndex != std::to_string(govIndex)) + cost += links.at("BufferStack") + links.at("BufferRightHead"); return cost; }; - costStatic = [label](const Config & config) + costStatic = [label](const Config & config, const std::map<std::string, int> &) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); @@ -617,7 +613,7 @@ void Transition::initGoldEagerRight_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getWordIndex(); @@ -630,7 +626,7 @@ void Transition::initGoldEagerRight_rel(std::string label) return cost; }; - costStatic = [label](const Config & config) + costStatic = [label](const Config & config, const std::map<std::string, int> &) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); @@ -662,7 +658,7 @@ void Transition::initStandardRight_rel(std::string label) sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> &) { auto govIndex = config.getStack(1); auto depIndex = config.getStack(0); @@ -688,7 +684,7 @@ void Transition::initEagerRight() sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); sequence.emplace_back(Action::pushWordIndexOnStack()); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { auto depIndex = config.getWordIndex(); auto govIndex = config.getStack(0); @@ -703,7 +699,7 @@ void Transition::initEagerRight() return cost; }; - costStatic = [](const Config & config) + costStatic = [](const Config & config, const std::map<std::string, int> &) { auto govIndex = config.getStack(0); auto depIndex = config.getWordIndex(); @@ -722,17 +718,12 @@ void Transition::initReduce_strict() sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> & links) { - auto stackIndex = config.getStack(0); - auto wordIndex = config.getWordIndex(); - - if (!config.isToken(stackIndex)) + if (!config.isToken(config.getStack(0))) return 0; - int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config); - - return cost; + return links.at("StackRight"); }; costStatic = costDynamic; @@ -743,7 +734,7 @@ void Transition::initGoldReduce_strict() sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack(0)); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { auto stackIndex = config.getStack(0); auto wordIndex = config.getWordIndex(); @@ -776,7 +767,7 @@ void Transition::initReduce_relaxed() { sequence.emplace_back(Action::popStack(0)); - costDynamic = [](const Config & config) + costDynamic = [](const Config & config, const std::map<std::string, int> &) { auto stackIndex = config.getStack(0); auto wordIndex = config.getWordIndex(); @@ -799,7 +790,7 @@ void Transition::initEOS(int bufferIndex) sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); sequence.emplace_back(Action::emptyStack()); - costDynamic = [bufferIndex](const Config & config) + costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1) @@ -813,7 +804,7 @@ void Transition::initEOS(int bufferIndex) void Transition::initNotEOS(int bufferIndex) { - costDynamic = [bufferIndex](const Config & config) + costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex); if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1) @@ -829,7 +820,7 @@ void Transition::initDeprel(std::string label) { sequence.emplace_back(Action::deprel(label)); - costDynamic = [label](const Config & config) + costDynamic = [label](const Config & config, const std::map<std::string, int> &) { return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1; }; @@ -857,7 +848,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s toAddUtf8 = util::splitAsUtf8(toAdd); sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8)); - costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config) + costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config, const std::map<std::string, int> &) { int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue); int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue); @@ -883,7 +874,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind sequence.emplace_back(Action::uppercase(col, objectValue, indexValue)); - costDynamic = [col, objectValue, indexValue](const Config & config) + costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -908,7 +899,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue)); - costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) + costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -932,7 +923,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index auto objectValue = Config::str2object(obj); int indexValue = std::stoi(index); - costDynamic = [col, objectValue, indexValue](const Config & config) + costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -953,7 +944,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind sequence.emplace_back(Action::lowercase(col, objectValue, indexValue)); - costDynamic = [col, objectValue, indexValue](const Config & config) + costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); @@ -978,7 +969,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue)); - costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config) + costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &) { int lineIndex = config.getRelativeWordIndex(objectValue, indexValue); auto & expectedValue = config.getConst(col, lineIndex, 0); diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 8f146e7f04e39e1bc15495dbc1e35107fcd43799..63590baa43484fa37aa340c682016fd1ddae4a58 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -40,9 +40,11 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC using Pair = std::pair<Transition*, int>; std::vector<Pair> appliableTransitions; + auto links = computeLinks(c); + for (unsigned int i = 0; i < transitions.size(); i++) if (transitions[i].appliable(c)) - appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c)); + appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links)); std::sort(appliableTransitions.begin(), appliableTransitions.end(), [](const Pair & a, const Pair & b) @@ -80,12 +82,64 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c) return result; } +std::map<std::string, int> TransitionSet::computeLinks(const Config & c) +{ + std::map<std::string, int> links{{"StackRight", 0}, {"BufferRight", 0}, {"BufferRightHead", 0}, {"BufferStack", 0}}; + + if (c.has(Config::headColName,0,0)) + { + int nbLinksStackRight = 0; + int nbLinksBufferRight = 0; + int nbLinksBufferRightHead = 0; + int nbLinksBufferStack = 0; + if (c.hasStack(0)) + { + if ((std::size_t)std::stoi(c.getConst(Config::headColName, c.getStack(0), 0)) >= c.getWordIndex()) + nbLinksStackRight++; + auto childs = util::split(c.getConst(Config::childsColName, c.getStack(0), 0), '|'); + for (auto & child : childs) + { + if ((std::size_t)std::stoi(child) >= c.getWordIndex()) + nbLinksStackRight++; + } + } + + auto head = c.getConst(Config::headColName, c.getWordIndex(), 0); + if (head != "_" and (std::size_t)std::stoi(c.getConst(Config::headColName, c.getWordIndex(), 0)) > c.getWordIndex()) + { + nbLinksBufferRight++; + nbLinksBufferRightHead++; + } + auto childs = util::split(c.getConst(Config::childsColName, c.getWordIndex(), 0), '|'); + for (auto & child : childs) + if ((std::size_t)std::stoi(child) > c.getWordIndex()) + nbLinksBufferRight++; + auto bufferHead = c.getConst(Config::headColName, c.getWordIndex(), 0); + for (unsigned int i = 0; i < c.getStackSize(); i++) + { + auto stackHead = c.getConst(Config::headColName, c.getStack(i), 0); + if (bufferHead != "_" and stackHead != "_") + if ((std::size_t)std::stoi(bufferHead) == c.getStack(i) or (std::size_t)std::stoi(stackHead) == c.getWordIndex()) + nbLinksBufferStack++; + } + + links.at("StackRight") = nbLinksStackRight; + links.at("BufferRight") = nbLinksBufferRight; + links.at("BufferRightHead") = nbLinksBufferRightHead; + links.at("BufferStack") = nbLinksBufferStack; + } + + return links; +} + std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic) { int bestCost = std::numeric_limits<int>::max(); std::vector<Transition *> result; std::vector<int> costs(transitions.size()); + auto links = computeLinks(c); + for (unsigned int i = 0; i < transitions.size(); i++) { if (!appliableTransitions[i]) @@ -94,7 +148,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi continue; } - int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c); + int cost = dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links); costs[i] = cost; if (cost < bestCost) @@ -104,7 +158,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi for (unsigned int i = 0; i < transitions.size(); i++) if (costs[i] == bestCost) result.emplace_back(&transitions[i]); - + return result; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a0d017248c25a453003aee653e65a5900b7dbb83..c65b72adb0b6b1f36e7c1e4c688071d332d71c04 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -117,6 +117,10 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: } transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]); + + for (auto & trans : goldTransitions) + if (trans == transition) + goldTransition = trans; } else {