Skip to content
Snippets Groups Projects
Commit 65bb83a7 authored by Franck Dary's avatar Franck Dary
Browse files

Only use dynamic oracle and may have corrected a problem

parent 93b372df
Branches
No related tags found
No related merge requests found
......@@ -323,12 +323,8 @@ void Transition::initEagerLeft_rel(std::string label)
costDynamic = [label](const Config & config)
{
auto depIndex = config.getStack(0);
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
auto govIndex = config.getWordIndex();
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
int cost = getNbLinkedWith(govIndex+1, getLastIndexOfSentence(govIndex, config), Config::Object::Buffer, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0))
......@@ -366,7 +362,7 @@ void Transition::initStandardLeft_rel(std::string label)
return 0;
int cost = getNbLinkedWith(config.getWordIndex(), getLastIndexOfSentence(config.getWordIndex(), config), Config::Object::Buffer, depIndex, config);
cost += getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
cost += getNbLinkedWith(1, config.getStackSize()-2, Config::Object::Stack, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0))
++cost;
......@@ -408,14 +404,9 @@ void Transition::initEagerRight_rel(std::string label)
costDynamic = [label](const Config & config)
{
auto govIndex = config.getStack(0);
auto depIndex = config.getWordIndex();
auto depGovIndex = config.getConst(Config::headColName, depIndex, 0);
if (depGovIndex == std::to_string(govIndex) and label == config.getConst(Config::deprelColName, depIndex, 0))
return 0;
int cost = getNbLinkedWith(1, config.getStackSize()-1, Config::Object::Stack, depIndex, config);
int cost = getNbLinkedWith(1, config.getStackSize()-2, Config::Object::Stack, depIndex, config);
cost += getNbLinkedWithHead(depIndex+1, getLastIndexOfSentence(depIndex, config), Config::Object::Buffer, depIndex, config);
if (label != config.getConst(Config::deprelColName, depIndex, 0))
......@@ -704,7 +695,6 @@ void Transition::initLowercaseIndex(std::string col, std::string obj, std::strin
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
auto govIndex = config.getConst(Config::headColName, withIndex, 0);
auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);
int nbLinkedWith = 0;
......@@ -718,11 +708,10 @@ int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object ob
continue;
auto otherGovIndex = config.getConst(Config::headColName, index, 0);
auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);
if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
if (govIndex == std::to_string(index))
++nbLinkedWith;
if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
if (otherGovIndex == std::to_string(withIndex))
++nbLinkedWith;
}
......@@ -732,7 +721,6 @@ int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object ob
int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
auto govIndex = config.getConst(Config::headColName, withIndex, 0);
auto govIndexPredicted = config.getAsFeature(Config::headColName, withIndex);
int nbLinkedWith = 0;
......@@ -745,7 +733,7 @@ int Transition::getNbLinkedWithHead(int firstIndex, int lastIndex, Config::Objec
if (!config.isToken(index))
continue;
if (govIndex == std::to_string(index) and util::isEmpty(govIndexPredicted))
if (govIndex == std::to_string(index))
++nbLinkedWith;
}
......@@ -766,9 +754,8 @@ int Transition::getNbLinkedWithDeps(int firstIndex, int lastIndex, Config::Objec
continue;
auto otherGovIndex = config.getConst(Config::headColName, index, 0);
auto otherGovIndexPredicted = config.getAsFeature(Config::headColName, index);
if (otherGovIndex == std::to_string(withIndex) and util::isEmpty(otherGovIndexPredicted))
if (otherGovIndex == std::to_string(withIndex))
++nbLinkedWith;
}
......
......@@ -74,7 +74,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
Transition * transition = nullptr;
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, dynamicOracle);
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, true or dynamicOracle);
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment