Skip to content
Snippets Groups Projects
Commit 74a4a5f6 authored by Franck Dary's avatar Franck Dary
Browse files

Introduced static oracle to gain speed when dynamic oracle is not mandatory

parent 95131d6d
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,8 @@ class Transition
std::string name;
std::string state;
std::vector<Action> sequence;
std::function<int(const Config & config)> cost;
std::function<int(const Config & config)> costDynamic;
std::function<int(const Config & config)> costStatic;
private :
......@@ -54,7 +55,8 @@ class Transition
Transition(const std::string & name);
void apply(Config & config);
bool appliable(const Config & config) const;
int getCost(const Config & config) const;
int getCostDynamic(const Config & config) const;
int getCostStatic(const Config & config) const;
const std::string & getName() const;
};
......
......@@ -20,8 +20,8 @@ class TransitionSet
TransitionSet(const std::vector<std::string> & filenames);
TransitionSet(const std::string & filename);
std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
Transition * getBestAppliableTransition(const Config & c);
std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c, bool dynamic = false);
Transition * getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false);
std::vector<Transition *> getNAppliableTransitions(const Config & c, int n);
std::vector<int> getAppliableTransitions(const Config & c);
std::size_t getTransitionIndex(const Transition * transition) const;
......
......@@ -107,9 +107,17 @@ bool Transition::appliable(const Config & config) const
return true;
}
int Transition::getCost(const Config & config) const
int Transition::getCostDynamic(const Config & config) const
{
try {return cost(config);}
try {return costDynamic(config);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0;
}
int Transition::getCostStatic(const Config & config) const
{
try {return costStatic(config);}
catch (std::exception & e) { util::myThrow(fmt::format("transition '{}' {}", name, e.what()));}
return 0;
......@@ -127,7 +135,7 @@ void Transition::initWrite(std::string colName, std::string object, std::string
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
costDynamic = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
......@@ -145,7 +153,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
costDynamic = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
......@@ -161,7 +169,7 @@ void Transition::initAdd(std::string colName, std::string object, std::string in
void Transition::initNothing()
{
cost = [](const Config &)
costDynamic = [](const Config &)
{
return 0;
};
......@@ -171,7 +179,7 @@ void Transition::initIgnoreChar()
{
sequence.emplace_back(Action::ignoreCurrentCharacter());
cost = [](const Config &)
costDynamic = [](const Config &)
{
return 0;
};
......@@ -181,7 +189,7 @@ void Transition::initEndWord()
{
sequence.emplace_back(Action::endWord());
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
if (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
return 0;
......@@ -196,7 +204,7 @@ void Transition::initAddCharToWord()
sequence.emplace_back(Action::addCurCharToCurWord());
sequence.emplace_back(Action::moveCharacterIndex(1));
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
if (!config.hasCharacter(config.getCharacterIndex()))
return std::numeric_limits<int>::max();
......@@ -226,7 +234,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
sequence.emplace_back(Action::addHypothesisRelativeRelaxed("FORM", Config::Object::Buffer, i, words[i]));
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
cost = [words](const Config & config)
costDynamic = [words](const Config & config)
{
if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max();
......@@ -247,14 +255,14 @@ void Transition::initSplit(int index)
{
sequence.emplace_back(Action::split(index));
cost = [index](const Config & config)
costDynamic = [index](const Config & config)
{
auto & transitions = config.getAppliableSplitTransitions();
if (index < 0 or index >= (int)transitions.size())
return std::numeric_limits<int>::max();
return transitions[index]->getCost(config);
return transitions[index]->getCostDynamic(config);
};
}
......@@ -263,13 +271,18 @@ void Transition::initEagerShift()
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
if (!config.isToken(config.getWordIndex()))
return 0;
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
};
costStatic = [](const Config &)
{
return 0;
};
}
void Transition::initStandardShift()
......@@ -277,7 +290,7 @@ void Transition::initStandardShift()
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
return 0;
};
......@@ -289,7 +302,7 @@ void Transition::initEagerLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0));
cost = [label](const Config & config)
costDynamic = [label](const Config & config)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -305,6 +318,18 @@ void Transition::initEagerLeft_rel(std::string label)
return cost;
};
costStatic = [label](const Config & config)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
auto govIndex = config.getWordIndex();
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
return 1;
};
}
void Transition::initStandardLeft_rel(std::string label)
......@@ -313,7 +338,7 @@ void Transition::initStandardLeft_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 1, label));
sequence.emplace_back(Action::popStack(1));
cost = [label](const Config & config)
costDynamic = [label](const Config & config)
{
auto depIndex = config.getStack(1);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -337,7 +362,7 @@ void Transition::initEagerLeft()
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
......@@ -359,7 +384,7 @@ void Transition::initEagerRight_rel(std::string label)
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [label](const Config & config)
costDynamic = [label](const Config & config)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
......@@ -376,6 +401,18 @@ void Transition::initEagerRight_rel(std::string label)
return cost;
};
costStatic = [label](const Config & config)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
return 1;
};
}
void Transition::initStandardRight_rel(std::string label)
......@@ -384,7 +421,7 @@ void Transition::initStandardRight_rel(std::string label)
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack(0));
cost = [label](const Config & config)
costDynamic = [label](const Config & config)
{
auto govIndex = config.getStack(1);
auto depIndex = config.getStack(0);
......@@ -409,7 +446,7 @@ void Transition::initEagerRight()
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
......@@ -430,7 +467,20 @@ void Transition::initReduce_strict()
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!config.isToken(stackIndex))
return 0;
int cost = getNbLinkedWithDeps(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
return cost;
};
costDynamic = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
......@@ -442,13 +492,15 @@ void Transition::initReduce_strict()
return cost;
};
costStatic = costDynamic;
}
void Transition::initReduce_relaxed()
{
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
costDynamic = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
......@@ -469,7 +521,7 @@ void Transition::initEOS(int bufferIndex)
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack());
cost = [bufferIndex](const Config & config)
costDynamic = [bufferIndex](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
......@@ -483,7 +535,7 @@ void Transition::initDeprel(std::string label)
{
sequence.emplace_back(Action::deprel(label));
cost = [label](const Config & config)
costDynamic = [label](const Config & config)
{
return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
};
......@@ -509,7 +561,7 @@ void Transition::initTransformSuffix(std::string fromCol, std::string fromObj, s
toAddUtf8 = util::splitAsUtf8(toAdd);
sequence.emplace_back(Action::transformSuffix(fromCol, fromObjectValue, fromIndexValue, toCol, toObjectValue, toIndexValue, toRemoveUtf8, toAddUtf8));
cost = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config)
costDynamic = [fromObjectValue, fromIndexValue, toObjectValue, toIndexValue, toAddUtf8, toRemoveUtf8, fromCol, toCol](const Config & config)
{
int fromLineIndex = config.getRelativeWordIndex(fromObjectValue, fromIndexValue);
int toLineIndex = config.getRelativeWordIndex(toObjectValue, toIndexValue);
......@@ -533,7 +585,7 @@ void Transition::initUppercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::uppercase(col, objectValue, indexValue));
cost = [col, objectValue, indexValue](const Config & config)
costDynamic = [col, objectValue, indexValue](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -556,7 +608,7 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::uppercaseIndex(col, objectValue, indexValue, inIndexValue));
cost = [col, objectValue, indexValue, inIndexValue](const Config & config)
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -580,7 +632,7 @@ void Transition::initLowercase(std::string col, std::string obj, std::string ind
sequence.emplace_back(Action::lowercase(col, objectValue, indexValue));
cost = [col, objectValue, indexValue](const Config & config)
costDynamic = [col, objectValue, indexValue](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......@@ -603,7 +655,7 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin
sequence.emplace_back(Action::lowercaseIndex(col, objectValue, indexValue, inIndexValue));
cost = [col, objectValue, indexValue, inIndexValue](const Config & config)
costDynamic = [col, objectValue, indexValue, inIndexValue](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
......
......@@ -35,14 +35,14 @@ void TransitionSet::addTransitionsFromFile(const std::string & filename)
std::fclose(file);
}
std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c)
std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c, bool dynamic)
{
using Pair = std::pair<Transition*, int>;
std::vector<Pair> appliableTransitions;
for (unsigned int i = 0; i < transitions.size(); i++)
if (transitions[i].appliable(c))
appliableTransitions.emplace_back(&transitions[i], transitions[i].getCost(c));
appliableTransitions.emplace_back(&transitions[i], dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c));
std::sort(appliableTransitions.begin(), appliableTransitions.end(),
[](const Pair & a, const Pair & b)
......@@ -80,17 +80,17 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c)
return result;
}
Transition * TransitionSet::getBestAppliableTransition(const Config & c)
Transition * TransitionSet::getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic)
{
Transition * result = nullptr;
int bestCost = std::numeric_limits<int>::max();
for (unsigned int i = 0; i < transitions.size(); i++)
{
if (!transitions[i].appliable(c))
if (!appliableTransitions[i])
continue;
int cost = transitions[i].getCost(c);
int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c);
if (cost == 0)
return &transitions[i];
......
......@@ -73,7 +73,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
Transition * transition = nullptr;
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle);
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
......@@ -301,7 +301,7 @@ void Trainer::fillDicts(SubConfig & config, bool debug)
}
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions);
if (!goldTransition)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment