Skip to content
Snippets Groups Projects
Commit 3c41224b authored by Franck Dary's avatar Franck Dary
Browse files

Explore different transitions only for parser

parent a4487b91
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,7 @@ class Transition ...@@ -43,6 +43,7 @@ class Transition
void initGoldReduce_strict(); void initGoldReduce_strict();
void initReduce_relaxed(); void initReduce_relaxed();
void initEOS(int bufferIndex); void initEOS(int bufferIndex);
void initNotEOS(int bufferIndex);
void initNothing(); void initNothing();
void initIgnoreChar(); void initIgnoreChar();
void initEndWord(); void initEndWord();
...@@ -52,6 +53,7 @@ class Transition ...@@ -52,6 +53,7 @@ class Transition
void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule); void initTransformSuffix(std::string fromCol, std::string fromObj, std::string fromIndex, std::string toCol, std::string toObj, std::string toIndex, std::string rule);
void initUppercase(std::string col, std::string obj, std::string index); void initUppercase(std::string col, std::string obj, std::string index);
void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); void initUppercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
void initNothing(std::string col, std::string obj, std::string index);
void initLowercase(std::string col, std::string obj, std::string index); void initLowercase(std::string col, std::string obj, std::string index);
void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex);
......
...@@ -43,6 +43,8 @@ Transition::Transition(const std::string & name) ...@@ -43,6 +43,8 @@ Transition::Transition(const std::string & name)
[this](auto sm){initEOS(std::stoi(sm[1]));}}, [this](auto sm){initEOS(std::stoi(sm[1]));}},
{std::regex("NOTHING"), {std::regex("NOTHING"),
[this](auto){initNothing();}}, [this](auto){initNothing();}},
{std::regex("NOTEOS b\\.(.+)"),
[this](auto sm){initNotEOS(std::stoi(sm[1]));}},
{std::regex("IGNORECHAR"), {std::regex("IGNORECHAR"),
[this](auto){initIgnoreChar();}}, [this](auto){initIgnoreChar();}},
{std::regex("ENDWORD"), {std::regex("ENDWORD"),
...@@ -57,6 +59,8 @@ Transition::Transition(const std::string & name) ...@@ -57,6 +59,8 @@ Transition::Transition(const std::string & name)
[this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}}, [this](auto sm){(initUppercase(sm[1], sm[2], sm[3]));}},
{std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), {std::regex("UPPERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
[this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}}, [this](auto sm){(initUppercaseIndex(sm[1], sm[2], sm[3], sm[4]));}},
{std::regex("NOTHING (.+) ([bs])\\.(.+)"),
[this](auto sm){(initNothing(sm[1], sm[2], sm[3]));}},
{std::regex("LOWERCASE (.+) ([bs])\\.(.+)"), {std::regex("LOWERCASE (.+) ([bs])\\.(.+)"),
[this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}}, [this](auto sm){(initLowercase(sm[1], sm[2], sm[3]));}},
{std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"), {std::regex("LOWERCASEINDEX (.+) ([bs])\\.(.+) (.+)"),
...@@ -713,6 +717,20 @@ void Transition::initEOS(int bufferIndex) ...@@ -713,6 +717,20 @@ void Transition::initEOS(int bufferIndex)
costStatic = costDynamic; costStatic = costDynamic;
} }
void Transition::initNotEOS(int bufferIndex)
{
costDynamic = [bufferIndex](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
return 0;
};
costStatic = costDynamic;
}
void Transition::initDeprel(std::string label) void Transition::initDeprel(std::string label)
{ {
sequence.emplace_back(Action::deprel(label)); sequence.emplace_back(Action::deprel(label));
...@@ -815,6 +833,25 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin ...@@ -815,6 +833,25 @@ void Transition::initUppercaseIndex(std::string col, std::string obj, std::strin
costStatic = costDynamic; costStatic = costDynamic;
} }
void Transition::initNothing(std::string col, std::string obj, std::string index)
{
auto objectValue = Config::str2object(obj);
int indexValue = std::stoi(index);
costDynamic = [col, objectValue, indexValue](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto & expectedValue = config.getConst(col, lineIndex, 0);
std::string currentValue = config.getAsFeature(col, lineIndex).get();
if (expectedValue == currentValue)
return 0;
return 1;
};
costStatic = costDynamic;
}
void Transition::initLowercase(std::string col, std::string obj, std::string index) void Transition::initLowercase(std::string col, std::string obj, std::string index)
{ {
auto objectValue = Config::str2object(obj); auto objectValue = Config::str2object(obj);
......
...@@ -128,7 +128,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -128,7 +128,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
Transition * transition = nullptr; Transition * transition = nullptr;
auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
Transition * goldTransition = goldTransitions[0];
if (config.getState() == "parser")
goldTransitions[std::rand()%goldTransitions.size()];
int nbClasses = machine.getTransitionSet(config.getState()).size(); int nbClasses = machine.getTransitionSet(config.getState()).size();
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment