Skip to content
Snippets Groups Projects
Commit f0df0508 authored by Franck Dary's avatar Franck Dary
Browse files

Speed up oracle

parent 79837075
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,8 @@ class Transition
std::string name;
std::string state;
std::vector<Action> sequence;
std::function<int(const Config & config)> costDynamic;
std::function<int(const Config & config)> costStatic;
std::function<int(const Config & config, const std::map<std::string, int> & links)> costDynamic;
std::function<int(const Config & config, const std::map<std::string, int> & links)> costStatic;
std::function<bool(const Config & config)> precondition{[](const Config&){return true;}};
private :
......@@ -64,8 +64,8 @@ class Transition
void apply(Config & config, float entropy);
void apply(Config & config);
bool appliable(const Config & config) const;
int getCostDynamic(const Config & config) const;
int getCostStatic(const Config & config) const;
int getCostDynamic(const Config & config, const std::map<std::string, int> & links) const;
int getCostStatic(const Config & config, const std::map<std::string, int> & links) const;
const std::string & getName() const;
};
......
......@@ -28,6 +28,7 @@ class TransitionSet
Transition * getTransition(std::size_t index);
Transition * getTransition(const std::string & name);
std::size_t size() const;
std::map<std::string, int> computeLinks(const Config & c);
};
#endif
......@@ -112,6 +112,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
try
{
std::map<std::string, int> id2index;
std::map<int, std::vector<std::string>> childs;
int firstIndexOfSequence = getNbLines()-1;
for (int i = (int)getNbLines()-1; has(0, i, 0); --i)
{
......@@ -125,6 +126,7 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
id2index[getConst(idColName, i, 0)] = i;
}
if (hasColIndex(headColName))
{
for (int i = firstIndexOfSequence; i < (int)getNbLines(); ++i)
{
if (!isToken(i))
......@@ -133,8 +135,14 @@ void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sent
if (head == "0")
head = "-1";
else
{
childs[id2index[head]].emplace_back(fmt::format("{}",i));
head = std::to_string(id2index[head]);
}
}
for (auto it : childs)
get(Config::childsColName, it.first, 0) = util::join("|", it.second);
}
get(EOSColName, getNbLines()-1, 0) = EOSSymbol1;
} catch(std::exception & e) {util::myThrow(e.what());}
......
......@@ -175,17 +175,17 @@ bool Transition::appliable(const Config & config) const
return true;
}
int Transition::getCostDynamic(const Config & config) const
int Transition::getCostDynamic(const Config & config, const std::map<std::string, int> & links) const
{
try {return costDynamic(config);}
try {return costDynamic(config, links);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0;
}
int Transition::getCostStatic(const Config & config) const
int Transition::getCostStatic(const Config & config, const std::map<std::string, int> & links) const
{
try {return costStatic(config);}
try {return costStatic(config, links);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0;
......@@ -203,7 +203,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
costDynamic = [colName, objectValue, indexValue, value](const Config & config)
costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
......@@ -223,7 +223,7 @@ void Transition::initWriteScore(std::string colName, std::string object, std::st
sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue));
costDynamic = [](const Config &)
costDynamic = [](const Config &, const std::map<std::string, int> &)
{
return 0;
};
......@@ -238,7 +238,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
costDynamic = [colName, objectValue, indexValue, value](const Config & config)
costDynamic = [colName, objectValue, indexValue, value](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
......@@ -256,7 +256,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
void Transition::initNothing()
{
costDynamic = [](const Config &)
costDynamic = [](const Config &, const std::map<std::string, int> &)
{
return 0;
};
......@@ -268,7 +268,7 @@ void Transition::initIgnoreChar()
{
sequence.emplace_back(Action::ignoreCurrentCharacter());
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
auto letter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
auto goldWord = util::splitAsUtf8(std::string(config.getConst("FORM", config.getWordIndex(), 0)));
......@@ -286,7 +286,7 @@ void Transition::initEndWord()
{
sequence.emplace_back(Action::endWord());
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
return 0;
......@@ -304,7 +304,7 @@ void Transition::initAddCharToWord(int n)
sequence.emplace_back(Action::addCharsToCol("FORM", n, Config::Object::Buffer, 0));
sequence.emplace_back(Action::moveCharacterIndex(n));
costDynamic = [n](const Config & config)
costDynamic = [n](const Config & config, const std::map<std::string, int> &)
{
if (!config.hasCharacter(config.getCharacterIndex()+n-1))
return std::numeric_limits<int>::max();
......@@ -345,7 +345,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
}
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
costDynamic = [words](const Config & config)
costDynamic = [words](const Config & config, const std::map<std::string, int> &)
{
if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max();
......@@ -367,14 +367,14 @@ void Transition::initSplit(int index)
{
sequence.emplace_back(Action::split(index));
costDynamic = [index](const Config & config)
costDynamic = [index](const Config & config, const std::map<std::string, int> & links)
{
auto & transitions = config.getAppliableSplitTransitions();
if (index < 0 or index >= (int)transitions.size())
return std::numeric_limits<int>::max();
return transitions[index]->getCostDynamic(config);
return transitions[index]->getCostDynamic(config, links);
};
costStatic = costDynamic;
......@@ -384,25 +384,22 @@ void Transition::initEagerShift()
{
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> & links)
{
if (!config.isToken(config.getWordIndex()))
return 0;
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
return links.at("BufferStack");
};
costStatic = [](const Config &)
{
return 0;
};
costStatic = costDynamic;
}
void Transition::initGoldEagerShift()
{
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
if (!config.isToken(config.getWordIndex()))
return 0;
......@@ -410,7 +407,7 @@ void Transition::initGoldEagerShift()
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
};
costStatic = [](const Config &)
costStatic = [](const Config &, const std::map<std::string, int> &)
{
return 0;
};
......@@ -428,7 +425,7 @@ void Transition::initStandardShift()
{
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config &)
costDynamic = [](const Config &, const std::map<std::string, int> &)
{
return 0;
};
......@@ -442,23 +439,23 @@ void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
{
auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
int cost = 0;
if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost;
if (depGovIndex != std::to_string(govIndex))
++cost;
cost += links.at("StackRight");
return cost;
};
costStatic = [label](const Config & config)
costStatic = [label](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -477,7 +474,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex();
......@@ -490,7 +487,7 @@ void Transition::initGoldEagerLeft_rel(std::string label)
return cost;
};
costStatic = [label](const Config & config)
costStatic = [label](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -522,7 +519,7 @@ void Transition::initStandardLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
sequence.emplace_back(Action::popStack(1));
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(1);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -548,7 +545,7 @@ void Transition::initEagerLeft()
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(0);
auto govIndex = config.getWordIndex();
......@@ -562,7 +559,7 @@ void Transition::initEagerLeft()
return cost;
};
costStatic = [](const Config & config)
costStatic = [](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -581,24 +578,23 @@ void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> & links)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
int cost = 0;
if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost;
if (depGovIndex == std::to_string(govIndex))
++cost;
if (depGovIndex != std::to_string(govIndex))
cost += links.at("BufferStack") + links.at("BufferRightHead");
return cost;
};
costStatic = [label](const Config & config)
costStatic = [label](const Config & config, const std::map<std::string, int> &)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
......@@ -617,7 +613,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getWordIndex();
......@@ -630,7 +626,7 @@ void Transition::initGoldEagerRight_rel(std::string label)
return cost;
};
costStatic = [label](const Config & config)
costStatic = [label](const Config & config, const std::map<std::string, int> &)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
......@@ -662,7 +658,7 @@ void Transition::initStandardRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0));
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{
auto govIndex = config.getStack(1);
auto depIndex = config.getStack(0);
......@@ -688,7 +684,7 @@ void Transition::initEagerRight()
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::pushWordIndexOnStack());
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
auto depIndex = config.getWordIndex();
auto govIndex = config.getStack(0);
......@@ -703,7 +699,7 @@ void Transition::initEagerRight()
return cost;
};
costStatic = [](const Config & config)
costStatic = [](const Config & config, const std::map<std::string, int> &)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
......@@ -722,17 +718,12 @@ void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> & links)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!config.isToken(stackIndex))
if (!config.isToken(config.getStack(0)))
return 0;
int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
return cost;
return links.at("StackRight");
};
costStatic = costDynamic;
......@@ -743,7 +734,7 @@ void Transition::initGoldReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
......@@ -776,7 +767,7 @@ void Transition::initReduce_relaxed()
{
sequence.emplace_back(Action::popStack(0));
costDynamic = [](const Config & config)
costDynamic = [](const Config & config, const std::map<std::string, int> &)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
......@@ -799,7 +790,7 @@ void Transition::initEOS(int bufferIndex)
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack());
costDynamic = [bufferIndex](const Config & config)
costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
......@@ -813,7 +804,7 @@ void Transition::initEOS(int bufferIndex)
void Transition::initNotEOS(int bufferIndex)
{
costDynamic = [bufferIndex](const Config & config)
costDynamic = [bufferIndex](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
......@@ -829,7 +820,7 @@ void Transition::initDeprel(std::string label)
{
sequence.emplace_back(Action::deprel(label));
costDynamic = [label](const Config & config)
costDynamic = [label](const Config & config, const std::map<std::string, int> &)
{
return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
};
......@@ -857,7 +848,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s
toAddUtf8 = util::splitAsUtf8(toAdd);
sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8));
costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config)
costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config, const std::map<std::string, int> &)
{
int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue);
int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue);
......@@ -883,7 +874,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::uppercase(col, objectValue, indexValue));
costDynamic = [col, objectValue, indexValue](const Config & config)
costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -908,7 +899,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue));
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -932,7 +923,7 @@ void Transition::initNothing(std::string col, std::string obj, std::string index
auto objectValue = Config::str2object(obj);
int indexValue = std::stoi(index);
costDynamic = [col, objectValue, indexValue](const Config & config)
costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -953,7 +944,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::lowercase(col, objectValue, indexValue));
costDynamic = [col, objectValue, indexValue](const Config & config)
costDynamic = [col, objectValue, indexValue](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -978,7 +969,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue));
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config, const std::map<std::string, int> &)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......
......@@ -40,9 +40,11 @@ std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsC
using Pair = std::pair<Transition*, int>;
std::vector<Pair> appliableTransitions;
auto links = computeLinks(c);
for (unsigned int i = 0; i < transitions.size(); i++)
if (transitions[i].appliable(c))
appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c));
appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links));
std::sort(appliableTransitions.begin(), appliableTransitions.end(),
[](const Pair & a, const Pair & b)
......@@ -80,12 +82,64 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
return result;
}
std::map<std::string, int> TransitionSet::computeLinks(const Config & c)
{
std::map<std::string, int> links{{"StackRight", 0}, {"BufferRight", 0}, {"BufferRightHead", 0}, {"BufferStack", 0}};
if (c.has(Config::headColName,0,0))
{
int nbLinksStackRight = 0;
int nbLinksBufferRight = 0;
int nbLinksBufferRightHead = 0;
int nbLinksBufferStack = 0;
if (c.hasStack(0))
{
if ((std::size_t)std::stoi(c.getConst(Config::headColName, c.getStack(0), 0)) >= c.getWordIndex())
nbLinksStackRight++;
auto childs = util::split(c.getConst(Config::childsColName, c.getStack(0), 0), '|');
for (auto & child : childs)
{
if ((std::size_t)std::stoi(child) >= c.getWordIndex())
nbLinksStackRight++;
}
}
auto head = c.getConst(Config::headColName, c.getWordIndex(), 0);
if (head != "_" and (std::size_t)std::stoi(c.getConst(Config::headColName, c.getWordIndex(), 0)) > c.getWordIndex())
{
nbLinksBufferRight++;
nbLinksBufferRightHead++;
}
auto childs = util::split(c.getConst(Config::childsColName, c.getWordIndex(), 0), '|');
for (auto & child : childs)
if ((std::size_t)std::stoi(child) > c.getWordIndex())
nbLinksBufferRight++;
auto bufferHead = c.getConst(Config::headColName, c.getWordIndex(), 0);
for (unsigned int i = 0; i < c.getStackSize(); i++)
{
auto stackHead = c.getConst(Config::headColName, c.getStack(i), 0);
if (bufferHead != "_" and stackHead != "_")
if ((std::size_t)std::stoi(bufferHead) == c.getStack(i) or (std::size_t)std::stoi(stackHead) == c.getWordIndex())
nbLinksBufferStack++;
}
links.at("StackRight") = nbLinksStackRight;
links.at("BufferRight") = nbLinksBufferRight;
links.at("BufferRightHead") = nbLinksBufferRightHead;
links.at("BufferStack") = nbLinksBufferStack;
}
return links;
}
std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
{
int bestCost = std::numeric_limits<int>::max();
std::vector<Transition *> result;
std::vector<int> costs(transitions.size());
auto links = computeLinks(c);
for (unsigned int i = 0; i < transitions.size(); i++)
{
if (!appliableTransitions[i])
......@@ -94,7 +148,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
continue;
}
int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c);
int cost = dynamic ? transitions[i].getCostDynamic(c, links) : transitions[i].getCostStatic(c, links);
costs[i] = cost;
if (cost < bestCost)
......@@ -104,7 +158,7 @@ std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Confi
for (unsigned int i = 0; i < transitions.size(); i++)
if (costs[i] == bestCost)
result.emplace_back(&transitions[i]);
return result;
}
......
......@@ -117,6 +117,10 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
}
transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
for (auto & trans : goldTransitions)
if (trans == transition)
goldTransition = trans;
}
else
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment