// Oracle.cpp (authored by Franck Dary)
/*Copyright (c) 2019 Alexis Nasr && Franck Dary
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.*/
#include "Oracle.hpp"
#include "util.hpp"
#include "File.hpp"
#include "Action.hpp"
#include "ProgramParameters.hpp"
// Registry of every known oracle, keyed by oracle name. It is filled by
// createDatabase() and queried by getOracle().
std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle;
// Builds an oracle from its three callbacks.
//   initialize : run once by init(), lazily, before the first query.
//   findAction : returns the action to take in a given configuration.
//   getCost    : returns the cost of a given action in a given configuration.
// The oracle starts in the "not initialised" state.
Oracle::Oracle(std::function<void(Oracle *)> initialize,
               std::function<std::string(Config &, Oracle *)> findAction,
               std::function<int(Config &, Oracle *, const std::string &)> getCost)
{
  // The parameters shadow the members of the same name, hence the explicit this->.
  this->isInit = false;
  this->initialize = initialize;
  this->findAction = findAction;
  this->getCost = getCost;
}
// Returns the oracle registered under `name`, building the registry first if
// needed. When `filename` is not empty it is stored on the oracle before
// returning. An unknown name prints an error and terminates the program.
Oracle * Oracle::getOracle(const std::string & name, std::string filename)
{
  createDatabase();

  const auto found = str2oracle.find(name);
  if (found == str2oracle.end())
  {
    fprintf(stderr, "ERROR (%s) : invalid oracle name \'%s\'. Aborting.\n", ERRINFO, name.c_str());
    exit(1);
    return nullptr;
  }

  Oracle * oracle = found->second.get();
  if (!filename.empty())
    oracle->setFilename(filename);

  return oracle;
}
// Convenience overload: looks the oracle up without changing its filename
// (an empty filename is forwarded, which the two-argument overload ignores).
Oracle * Oracle::getOracle(const std::string & name)
{
  const std::string noFilename;
  return getOracle(name, noFilename);
}
// Returns the cost of `action` in `config` according to this oracle's cost
// callback. The oracle is initialised on first use.
int Oracle::getActionCost(Config & config, const std::string & action)
{
  if (!isInit)
  {
    init();
  }

  const int cost = getCost(config, this, action);
  return cost;
}
// Returns the action chosen by this oracle's findAction callback for `config`.
// The oracle is initialised on first use.
std::string Oracle::getAction(Config & config)
{
  if (!isInit)
  {
    init();
  }

  std::string chosen = findAction(config, this);
  return chosen;
}
// Runs the oracle's initialize callback. The flag is raised before the
// callback is invoked, exactly in that order.
void Oracle::init()
{
  this->isInit = true;
  this->initialize(this);
}
// Stores the path of the file that the initialize callbacks of the
// file-backed oracles open.
void Oracle::setFilename(const std::string & filename)
{
  // The parameter shadows the member, so go through this-> to reach it.
  auto & target = this->filename;
  target = filename;
}
// Registers every built-in oracle in str2oracle. Each registration passes
// three callbacks to the Oracle constructor: initialize, findAction, getCost.
void Oracle::createDatabase()
{
// Function-local guard so the registrations below run only on the first call.
// This is a separate variable from the per-oracle isInit member set in the
// constructor.
static bool isInit = false;
if(isInit)
return;
isInit = true;
// "null" oracle: a placeholder that must never be queried. Both callbacks
// print an error and terminate the program. The cost callback now names the
// entry point that actually reaches it (getActionCost) instead of repeating
// the getAction message.
str2oracle.emplace("null", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    fprintf(stderr, "ERROR (%s) : getActionCost called on null Oracle. Aborting.\n", ERRINFO);
    exit(1);
    return 1;
  })));
auto backSys = [](Config & c, Oracle * oracle)
{
if (oracle->data.count("systematic"))
{
if (Action("BACK " + oracle->data["systematic"]).appliable(c))
return std::string("BACK " + oracle->data["systematic"]);
return std::string("EPSILON");
}
return std::string("EPSILON");
};
str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
while (fscanf(fd, "%[^\n]\n", b1) == 1)
{
auto line = util::split(b1);
if (line.size() == 2)
oracle->data[line[0]] = line[1];
else
{
fprintf(stderr, "ERROR (%s) : Invalid line \'%s\'. Aborting.\n", ERRINFO, b1);
exit(1);
}
}
},
backSys,
[](Config &, Oracle *, const std::string &)
{
return 0;
})));
str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
while (fscanf(fd, "%[^\n]\n", b1) == 1)
{
auto line = util::split(b1);
if (line.size() == 2)
oracle->data[line[0]] = line[1];
else
{
fprintf(stderr, "ERROR (%s) : Invalid line \'%s\'. Aborting.\n", ERRINFO, b1);
exit(1);
}
}
},
backSys,
[](Config &, Oracle *, const std::string &)
{
return 0;
})));
str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle)
{
File file(oracle->filename, "r");
FILE * fd = file.getDescriptor();
char b1[1024];
while (fscanf(fd, "%[^\n]\n", b1) == 1)
{
auto line = util::split(b1);
if (line.size() == 2)
oracle->data[line[0]] = line[1];
else
{
fprintf(stderr, "ERROR (%s) : Invalid line \'%s\'. Aborting.\n", ERRINFO, b1);
exit(1);
}
}
},
backSys,
[](Config &, Oracle *, const std::string &)
{
return 0;
})));
// "segmenter" oracle: cost-only oracle of a trainable classifier (asking it
// for an action aborts). "EOS b.0" costs 0 exactly when the reference value of
// the sequence-delimiter tape at offset 0 is the delimiter; any other action
// costs 0 exactly when it is not.
str2oracle.emplace("segmenter", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    const bool refIsDelimiter = (c.getTape(ProgramParameters::sequenceDelimiterTape).getRef(0) == ProgramParameters::sequenceDelimiter);
    const bool actionIsEos = (action == "EOS b.0");
    if (refIsDelimiter == actionIsEos)
      return 0;
    return 1;
  })));
// "tagger" oracle: cost-only oracle of a trainable classifier (asking it for
// an action aborts). The only zero-cost action is the one that writes the
// reference POS value at offset 0.
str2oracle.emplace("tagger", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    const auto goldAction = "WRITE b.0 POS " + c.getTape("POS").getRef(0);
    if (action == goldAction)
      return 0;
    return 1;
  })));
// "tokenizer" oracle: cost-only oracle of a trainable classifier (asking it
// for an action aborts). The cost callback returns 0 or 1 for the split-word
// actions (recognised by their '@'-separated last argument), ENDWORD and
// ADDCHARTOWORD; everything else costs 1.
str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
auto & currentWordRef = c.getTape("FORM").getRef(0);
auto & currentWordHyp = c.getTape("FORM").getHyp(0);
// Last space-separated word of the action, split on '@'.
auto splited = util::split(util::split(action, ' ').back(),'@');
// More than two '@'-separated parts: the first part must match the raw input
// at the raw-input head, and part i must equal the reference FORM at offset i.
if (splited.size() > 2)
{
// NOTE(review): this bound is measured from rawInput.begin(), while the loop
// below indexes from rawInputHeadIndex - confirm the head offset is not
// missing here.
if (c.rawInput.begin() + splited[0].size() >= c.rawInput.end())
return 1;
for (unsigned int i = 0; i < splited[0].size(); i++)
if (splited[0][i] != c.rawInput[c.rawInputHeadIndex+i])
return 1;
for (unsigned int i = 0; i < splited.size(); i++)
if (c.getTape("FORM").getRef(i) != splited[i])
return 1;
return 0;
}
// Nested ifs without braces: ENDWORD costs 0 only when hyp equals ref.
if (currentWordRef == currentWordHyp)
if (action == "ENDWORD")
return 0;
// ADDCHARTOWORD is only considered while the hypothesis is shorter than the
// reference word.
if (action == "ADDCHARTOWORD" && currentWordRef.size() > currentWordHyp.size())
{
if (c.hasTape("ID") && c.getTape("ID").getRef(0).empty())
{
fprintf(stderr, "ERROR (%s) : ID.getRef(0) is empty. Aborting.\n", ERRINFO);
exit(1);
}
// An ID that splits on '-' into several parts makes the action cost 1.
if (c.hasTape("ID") && util::split(c.getTape("ID").getRef(0), '-').size() > 1)
return 1;
// The remaining characters of the reference word must match the raw input
// starting at the raw-input head.
// NOTE(review): rawInput is indexed without a visible bound check here.
for (unsigned int i = 0; i < (currentWordRef.size()-currentWordHyp.size()); i++)
if (currentWordRef[currentWordHyp.size()+i] != c.rawInput[c.rawInputHeadIndex+i])
return 1;
return 0;
}
return 1;
})));
// "eos" oracle: cost-only oracle of a trainable classifier (asking it for an
// action aborts). The expected action writes, on the sequence-delimiter tape
// at offset 0, the delimiter when the reference holds it and "0" otherwise.
//
// Fix: the callback used to return the boolean `action == expected` itself,
// i.e. cost 1 for the correct action and 0 for any wrong one - the inverse of
// every other cost callback in this file. It now returns 0 on a match and 1
// otherwise.
str2oracle.emplace("eos", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    const std::string delimiter(ProgramParameters::sequenceDelimiter);
    const bool refIsDelimiter = (c.getTape(ProgramParameters::sequenceDelimiterTape).getRef(0) == delimiter);
    const std::string expected = "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (refIsDelimiter ? delimiter : std::string("0"));
    return action == expected ? 0 : 1;
  })));
// "morpho" oracle: cost-only oracle of a trainable classifier (asking it for
// an action aborts). Costs by action prefix:
//   WRITE...   : 0 only for the action writing the whole reference MORPHO value.
//   NOTHING... : 0 when the '|'-separated parts of hyp and ref are equal,
//                otherwise the absolute difference of their part counts.
//   ADD...     : +1 if the added part is not in the reference, +1 if it is
//                already in the hypothesis.
//   anything else costs 1.
str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    if (action.compare(0, 5, "WRITE") == 0)
    {
      if (action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0))
        return 0;
      return 1;
    }

    auto & hypValue = c.getTape("MORPHO").getHyp(0);
    auto & refValue = c.getTape("MORPHO").getRef(0);

    std::vector<std::string> refParts;
    if (!refValue.empty())
      refParts = util::split(refValue, '|');
    std::vector<std::string> hypParts;
    if (!hypValue.empty())
      hypParts = util::split(hypValue, '|');

    if (action.compare(0, 7, "NOTHING") == 0)
    {
      // NOTE(review): with equal part counts but different contents this
      // yields 0 - confirm that is intended.
      const int sizeGap = std::abs((int)(refParts.size()-hypParts.size()));
      if (refParts == hypParts)
        return 0;
      return sizeGap;
    }

    if (action.compare(0, 3, "ADD") != 0)
      return 1;

    const auto addedPart = util::split(action, ' ').back();
    const std::set<std::string> inHyp(hypParts.begin(), hypParts.end());
    const std::set<std::string> inRef(refParts.begin(), refParts.end());

    const int missingFromRef = inRef.count(addedPart) ? 0 : 1;
    const int alreadyInHyp = inHyp.count(addedPart) ? 1 : 0;
    return missingFromRef + alreadyInHyp;
  })));
// "strategy_morpho_whole": always answers "MOVE morpho 1"; every action costs 0.
str2oracle.emplace("strategy_morpho_whole", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    const std::string move("MOVE morpho 1");
    return move;
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_morpho_parts": answers "MOVE morpho 1" when the most recent past
// action (accent-stripped, lower-cased) is "nothing", and "MOVE morpho 0"
// otherwise. Every action costs 0.
// NOTE(review): unlike the other strategies, pastActions is read without an
// emptiness check - confirm getElem(0) is safe on the first call.
str2oracle.emplace("strategy_morpho_parts", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    const std::string lastAction = util::noAccentLower(c.pastActions.getElem(0).second.name);
    const bool partsDone = (lastAction == "nothing");
    return partsDone ? std::string("MOVE morpho 1") : std::string("MOVE morpho 0");
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_segmenter": "MOVE segmenter 0" while no action has been taken yet,
// "MOVE segmenter 1" afterwards. Every action costs 0.
str2oracle.emplace("strategy_segmenter", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    const bool firstCall = (c.pastActions.size() == 0);
    return firstCall ? std::string("MOVE segmenter 0") : std::string("MOVE segmenter 1");
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_tagger": "MOVE tagger 0" before any action has been taken; after
// a "tagger" or "error_tagger" state, "MOVE tagger 1". For any other previous
// state the state name stays empty and the movement 0 (unchanged behaviour).
// Every action costs 0.
// The previously computed but never read `previousAction` local was removed.
str2oracle.emplace("strategy_tagger", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE tagger 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);

    std::string newState;
    int movement = 0;
    if (previousState == "tagger" || previousState == "error_tagger")
    {
      newState = "tagger";
      movement = 1;
    }

    return "MOVE " + newState + " " + std::to_string(movement);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_lemmatizer": chains lemmatizer_lookup -> (lemmatizer_rules when the
// lookup answered "notfound") -> lemmatizer_case, then moves one step forward
// and returns to lemmatizer_lookup. Every action costs 0.
str2oracle.emplace("strategy_lemmatizer", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE lemmatizer_lookup 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    // Default: any other previous state restarts the cycle on the next word.
    std::string nextState = "lemmatizer_lookup";
    int step = 1;

    if (previousState == "lemmatizer_rules")
    {
      nextState = "lemmatizer_case";
      step = 0;
    }
    else if (previousState == "lemmatizer_lookup")
    {
      nextState = (previousAction == "notfound") ? "lemmatizer_rules" : "lemmatizer_case";
      step = 0;
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_tokenizer": always stays in the "tokenizer" state. The movement is
// the number of '@'-separated parts after a "splitword" action, 1 after
// "endword", and 0 otherwise. Every action costs 0.
str2oracle.emplace("strategy_tokenizer", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE tokenizer 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    const std::string nextState = "tokenizer";
    int step = 0;

    const auto words = util::split(previousAction, ' ');
    if (words[0] == "splitword")
      step = (int)util::split(words[1], '@').size();
    else if (previousAction == "endword")
      step = 1;

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_tokenizer,tagger": alternates between tokenizer and tagger.
// After the tokenizer, "splitword"/"endword" hand over to the tagger (moving
// by 1 only for "splitword"); any other tokenizer action stays in the
// tokenizer. After "tagger"/"error_tagger" it returns to the tokenizer and
// moves by 1. Every action costs 0.
str2oracle.emplace("strategy_tokenizer,tagger", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE tokenizer 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    std::string nextState;
    int step = 0;

    if (previousState == "tokenizer")
    {
      const std::string verb = util::split(previousAction, ' ')[0];
      const bool wordWasSplit = (verb == "splitword");
      nextState = (wordWasSplit || verb == "endword") ? "tagger" : "tokenizer";
      if (wordWasSplit)
        step = 1;
    }
    else if (previousState == "tagger" || previousState == "error_tagger")
    {
      nextState = "tokenizer";
      step = 1;
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_parser_legacy": always targets the "parser" state. After a parser
// "shift"/"right" it moves by 1 unless the tapes are exhausted; after a parser
// "eos" with exhausted tapes it returns the empty string. After
// "error_parser" the movement is decided by the action taken one step earlier.
// Every action costs 0.
str2oracle.emplace("strategy_parser_legacy", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE parser 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    std::string nextState;
    int step = 0;

    if (previousState == "parser")
    {
      nextState = "parser";
      const std::string verb = util::split(previousAction, ' ')[0];
      // endOfTapes() is only consulted when a forward move was requested.
      if ((verb == "shift" || verb == "right") && !c.endOfTapes())
        step = 1;
      if (verb == "eos" && c.endOfTapes())
        return std::string("");
    }
    else if (previousState == "error_parser")
    {
      nextState = "parser";
      const std::string parserAction = util::noAccentLower(c.pastActions.getElem(1).second.name);
      const std::string verb = util::split(parserAction, ' ')[0];
      if (verb == "shift" || verb == "right")
        step = 1;
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_parser": after a parser "shift"/"right" hands over to the
// segmenter, otherwise stays in the parser; after the segmenter returns to the
// parser and moves by 1; after "error_parser" returns to the parser, moving by
// 1 when the action taken one step earlier was "shift"/"right".
// Every action costs 0.
str2oracle.emplace("strategy_parser", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE parser 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    std::string nextState;
    int step = 0;

    if (previousState == "parser")
    {
      const std::string verb = util::split(previousAction, ' ')[0];
      nextState = (verb == "shift" || verb == "right") ? "segmenter" : "parser";
    }
    else if (previousState == "segmenter")
    {
      nextState = "parser";
      step = 1;
    }
    else if (previousState == "error_parser")
    {
      nextState = "parser";
      const std::string parserAction = util::noAccentLower(c.pastActions.getElem(1).second.name);
      const std::string verb = util::split(parserAction, ' ')[0];
      if (verb == "shift" || verb == "right")
        step = 1;
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_tagger,morpho,lemmatizer,parser": per-word pipeline
//   tagger -> morpho (repeated until "nothing") -> lemmatizer_lookup
//   -> lemmatizer_rules (only after "notfound") -> lemmatizer_case -> parser
//   -> segmenter (after "shift"/"right") -> tagger on the next word.
// Only the segmenter -> tagger transition moves (by 1). An unrecognised
// previous state yields an "unknown(...)" state name. Every action costs 0.
str2oracle.emplace("strategy_tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE tagger 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    std::string nextState;
    int step = 0;

    if (previousState == "tagger")
    {
      nextState = "morpho";
    }
    else if (previousState == "morpho")
    {
      nextState = (previousAction == "nothing") ? "lemmatizer_lookup" : "morpho";
    }
    else if (previousState == "lemmatizer_lookup")
    {
      nextState = (previousAction == "notfound") ? "lemmatizer_rules" : "lemmatizer_case";
    }
    else if (previousState == "lemmatizer_rules")
    {
      nextState = "lemmatizer_case";
    }
    else if (previousState == "lemmatizer_case")
    {
      nextState = "parser";
    }
    else if (previousState == "parser")
    {
      const std::string verb = util::split(previousAction, ' ')[0];
      nextState = (verb == "shift" || verb == "right") ? "segmenter" : "parser";
    }
    else if (previousState == "segmenter")
    {
      nextState = "tagger";
      step = 1;
    }
    else
    {
      nextState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "strategy_tagger,morpho,lemmatizer,parser_sequential": variant of the
// tagger/morpho/lemmatizer/parser strategy in which each stage processes a
// batch of words before handing over to the next stage. Progress is tracked
// in function-local static maps that persist across calls and are reset when
// the configuration reports isFinal(). Every action costs 0.
str2oracle.emplace("strategy_tagger,morpho,lemmatizer,parser_sequential", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config & c, Oracle *)
{
if (c.pastActions.size() == 0)
return std::string("MOVE tagger 0");
std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);
std::string newState;
int movement = 0;
// done          : number of words handled by each stage in its current batch.
// lastIndexDone : head position of the last word each stage handled (-1 = none).
// todo          : batch size of each stage; initially 7, 5 and 3 with
//                 lookahead = 2, lowered to 1 once the segmenter has run.
// These statics are shared by every call of this lambda.
static constexpr int lookahead = 2;
static std::map<std::string,int> done{{"tagger",0},{"morpho",0},{"lemmatizer_case",0},{"parser",0}};
static std::map<std::string,int> lastIndexDone{{"tagger",-1},{"morpho",-1},{"lemmatizer_case",-1},{"parser",-1}};
static std::map<std::string,int> todo{{"tagger",3*lookahead+1},{"morpho",2*lookahead+1},{"lemmatizer_case",lookahead+1}};
if (previousState == "tagger")
{
done[previousState]++;
lastIndexDone[previousState] = c.getHead();
// Batch not finished: tag the next word. Otherwise jump to the word right
// after the last one morpho handled.
if (done[previousState] != todo[previousState])
{
newState = "tagger";
movement = 1;
}
else
{
done[previousState] = 0;
newState = "morpho";
movement = lastIndexDone[newState]-c.getHead()+1;
}
}
else if (previousState == "morpho")
{
newState = "morpho";
// A word only counts as done for morpho once the action "nothing" was taken.
if (previousAction == "nothing")
{
done[previousState]++;
lastIndexDone[previousState] = c.getHead();
if (done[previousState] != todo[previousState])
{
newState = "morpho";
movement = 1;
}
else
{
done[previousState] = 0;
newState = "lemmatizer_lookup";
movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1;
}
}
}
else if (previousState == "lemmatizer_lookup")
{
if (previousAction == "notfound")
newState = "lemmatizer_rules";
else
newState = "lemmatizer_case";
}
else if (previousState == "lemmatizer_rules")
newState = "lemmatizer_case";
else if (previousState == "lemmatizer_case")
{
newState = "parser";
done[previousState]++;
lastIndexDone[previousState] = c.getHead();
if (done[previousState] != todo[previousState])
{
// NOTE(review): the next word of the batch restarts at "lemmatizer_rules",
// whereas the morpho branch above hands over to "lemmatizer_lookup" -
// confirm which entry state is intended.
newState = "lemmatizer_rules";
movement = 1;
}
else
{
newState = "parser";
done[previousState] = 0;
movement = lastIndexDone[newState]-c.getHead()+1;
}
}
else if (previousState == "parser")
{
if (util::split(previousAction, ' ')[0] == "shift" || util::split(previousAction, ' ')[0] == "right")
{
newState = "segmenter";
movement = 0;
lastIndexDone[previousState] = c.getHead();
}
else
newState = "parser";
}
else if (previousState == "segmenter")
{
// Resume with the first stage that has not yet reached the end of the FORM
// tape, in the order tagger, morpho, lemmatizer, parser.
newState = "tagger";
movement = lastIndexDone[newState]-c.getHead()+1;
if (lastIndexDone[newState]+1 >= c.getTape("FORM").size())
{
newState = "morpho";
movement = lastIndexDone[newState]-c.getHead()+1;
if (lastIndexDone[newState]+1 >= c.getTape("FORM").size())
{
// NOTE(review): same question as above about "lemmatizer_rules" versus
// "lemmatizer_lookup" as the entry state.
newState = "lemmatizer_rules";
movement = lastIndexDone["lemmatizer_case"]-c.getHead()+1;
if (lastIndexDone["lemmatizer_case"]+1 >= c.getTape("FORM").size())
{
newState = "parser";
movement = lastIndexDone[newState]-c.getHead()+1;
}
}
}
// From now on every stage handles a single word per batch.
todo["tagger"] = 1;
todo["morpho"] = 1;
todo["lemmatizer_case"] = 1;
}
else
newState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
// Restore the initial bookkeeping once the configuration is final.
if (c.isFinal())
{
done = {{"tagger",0},{"morpho",0},{"lemmatizer_case",0},{"parser",0}};
lastIndexDone = {{"tagger",-1},{"morpho",-1},{"lemmatizer_case",-1},{"parser",-1}};
todo = {{"tagger",3*lookahead+1},{"morpho",2*lookahead+1},{"lemmatizer_case",lookahead+1}};
}
return "MOVE " + newState + " " + std::to_string(movement);
},
[](Config &, Oracle *, const std::string &)
{
return 0;
})));
// "strategy_tokenizer,tagger,morpho,lemmatizer,parser": full per-word pipeline
//   tokenizer -> tagger -> morpho (until "nothing") -> lemmatizer_lookup
//   -> lemmatizer_rules (only after "notfound") -> lemmatizer_case -> parser
//   -> segmenter (after "shift"/"right") -> tokenizer, or tagger when the next
//   word already has an ID hypothesis.
// Once the raw input is consumed and the current FORM hypothesis is empty, the
// strategy goes to the parser (movement -1 when coming from the tokenizer, 0
// otherwise). Every action costs 0.
str2oracle.emplace("strategy_tokenizer,tagger,morpho,lemmatizer,parser", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config & c, Oracle *) -> std::string
  {
    if (c.pastActions.size() == 0)
      return std::string("MOVE tokenizer 0");

    const std::string previousState = util::noAccentLower(c.pastActions.getElem(0).first);
    const std::string previousAction = util::noAccentLower(c.pastActions.getElem(0).second.name);

    std::string nextState;
    int step = 0;

    if (previousState == "tokenizer")
    {
      const std::string verb = util::split(previousAction, ' ')[0];
      const bool wordWasSplit = (verb == "splitword");
      nextState = (wordWasSplit || verb == "endword") ? "tagger" : "tokenizer";
      if (wordWasSplit)
        step = 1;
      if (c.rawInputHeadIndex >= (int)c.rawInput.size() && c.getTape("FORM").getHyp(0).empty())
      {
        nextState = "parser";
        step = -1;
      }
    }
    else if (previousState == "tagger")
    {
      nextState = "morpho";
    }
    else if (previousState == "morpho")
    {
      nextState = (previousAction == "nothing") ? "lemmatizer_lookup" : "morpho";
    }
    else if (previousState == "lemmatizer_lookup")
    {
      nextState = (previousAction == "notfound") ? "lemmatizer_rules" : "lemmatizer_case";
    }
    else if (previousState == "lemmatizer_rules")
    {
      nextState = "lemmatizer_case";
    }
    else if (previousState == "lemmatizer_case")
    {
      nextState = "parser";
    }
    else if (previousState == "parser")
    {
      const std::string verb = util::split(previousAction, ' ')[0];
      nextState = (verb == "shift" || verb == "right") ? "segmenter" : "parser";
    }
    else if (previousState == "segmenter")
    {
      nextState = c.getTape("ID").getHyp(1).empty() ? "tokenizer" : "tagger";
      step = c.endOfTapes() ? 0 : 1;
    }
    else
    {
      nextState = "unknown("+std::string(ERRINFO)+")("+previousState+")("+previousAction+")";
    }

    // Raw input exhausted with nothing left to tokenise: finish in the parser.
    if (previousState != "tokenizer" && c.rawInputHeadIndex >= (int)c.rawInput.size() && c.getTape("FORM").getHyp(0).empty())
    {
      nextState = "parser";
      step = 0;
    }

    return "MOVE " + nextState + " " + std::to_string(step);
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "signature" oracle: looks word forms up in a "form<TAB>signature" file and
// emits a MULTIWRITE action for the SGN tape over a window around the head.
// Every action costs 0.
//
// Fixes in the initialiser:
//  * every %[ conversion now carries a field width, so over-long fields can no
//    longer overflow the 1024-byte buffers;
//  * the loop that skips input until the first well-formed line used to spin
//    forever when fscanf kept returning EOF (empty or malformed file) or 0 (a
//    leading tab, nothing consumed); it now aborts with an error message.
str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];

    // Skip input up to and including the first well-formed line; that line is
    // not stored (presumably a header - TODO confirm).
    int matched = fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2);
    while (matched != 2)
    {
      // EOF: nothing left to read. 0: the next character is a tab, so no input
      // was consumed and retrying cannot make progress.
      if (matched == EOF || matched == 0)
      {
        fprintf(stderr, "ERROR (%s) : no valid line found in signature file. Aborting.\n", ERRINFO);
        exit(1);
      }
      matched = fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2);
    }

    while (fscanf(fd, "%1023[^\t]\t%1023[^\n]\n", b1, b2) == 2)
      oracle->data[b1] = b2;
  },
  [](Config & c, Oracle * oracle) -> std::string
  {
    // Window of at most `window` positions on each side of the head, expressed
    // as offsets relative to the head and clamped to the SGN tape.
    int window = 3;
    int start = std::max<int>(c.getHead()-window, 0) - c.getHead();
    int end = std::min<int>(c.getHead()+window, c.getTape("SGN").size()-1) - c.getHead();

    // Skip the leading positions that already have a signature hypothesis.
    while (start+c.getHead() < c.getTape("SGN").size() && !c.getTape("SGN").getHyp(start).empty())
      start++;
    // Drop the trailing positions whose FORM is empty.
    while (end >= 0 && c.getTape("FORM")[end].empty())
      end--;

    if (start > end)
      return std::string("NOTHING");

    std::string action("MULTIWRITE " + std::to_string(start) + " " + std::to_string(end) + " " + std::string("SGN"));
    for(int i = start; i <= end; i++)
    {
      // Prefer the reference form, fall back to the hypothesis.
      const std::string & form = c.getTape("FORM").getRef(i).empty() ? c.getTape("FORM").getHyp(i) : c.getTape("FORM").getRef(i);
      std::string signature;
      // Exact form first, then its accent-stripped lower-cased version.
      if (oracle->data.count(form))
        signature = oracle->data[form];
      else if (oracle->data.count(util::noAccentLower(form)))
        signature = oracle->data[util::noAccentLower(form)];
      else
        signature = "UNKNOWN";
      action += std::string(" ") + signature;
    }
    return action;
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "none" oracle: always answers "NOTHING"; every action costs 0.
str2oracle.emplace("none", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    return std::string("NOTHING");
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "lemma_lookup" oracle: dictionary-based lemmatiser. The file holds four
// tab-separated fields per line; the first two form the key "<f1>_<f2>" and
// the third is the stored value (the fourth is parsed but not kept).
// Every action costs 0.
//
// Fixes: every %[ conversion now carries a field width so an over-long field
// cannot overflow the 1024-byte buffers, and the lookups use find() instead of
// count() followed by operator[].
str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    char b2[1024];
    char b3[1024];
    char b4[1024];
    while (fscanf(fd, "%1023[^\t]\t%1023[^\t]\t%1023[^\t]\t%1023[^\n]\n", b1, b2, b3, b4) == 4)
      oracle->data[std::string(b1) + std::string("_") + b2] = b3;
  },
  [](Config & c, Oracle * oracle) -> std::string
  {
    const std::string & form = c.getTape("FORM")[0];
    const std::string & pos = c.getTape("POS")[0];

    // An ID that splits on '-' into several parts is left untouched.
    if (c.hasTape("ID"))
    {
      auto & id = c.getTape("ID")[0];
      if (!id.empty() && util::split(id, '-').size() > 1)
        return std::string("NOTHING");
    }

    // Exact form first, then its accent-stripped lower-cased version.
    auto entry = oracle->data.find(form + "_" + pos);
    if (entry == oracle->data.end())
      entry = oracle->data.find(util::noAccentLower(form) + "_" + pos);
    if (entry == oracle->data.end())
      return std::string("NOTFOUND");

    return std::string("WRITE b.0 LEMMA ") + entry->second;
  },
  [](Config &, Oracle *, const std::string &) -> int
  {
    return 0;
  })));
// "lemma_rules" oracle: cost-only oracle of a trainable classifier (asking it
// for an action aborts). The only zero-cost action applies the rule that
// util::getRule computes from the lower-cased form and the lower-cased
// reference lemma.
str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    const std::string & form = c.getTape("FORM")[0];
    const std::string & lemma = c.getTape("LEMMA").getRef(0);

    const std::string goldRule = util::getRule(util::toLowerCase(form), util::toLowerCase(lemma));
    const std::string goldAction = std::string("RULE LEMMA ON FORM ") + goldRule;

    if (action == goldAction)
      return 0;
    return 1;
  })));
// "lemma_case" oracle: cost-only oracle of a trainable classifier (asking it
// for an action aborts). The zero-cost action is "NOTHING" when the current
// lemma already equals the reference, "TOLOWER b.0 LEMMA" when lower-casing it
// gives the reference, "TOUPPER b.0 LEMMA" when upper-casing does. If none of
// these apply, every action costs 1.
str2oracle.emplace("lemma_case", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle *)
  {
    // Nothing to initialise.
  },
  [](Config &, Oracle *) -> std::string
  {
    fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
    exit(1);
    return std::string("");
  },
  [](Config & c, Oracle *, const std::string & action) -> int
  {
    const std::string & current = c.getTape("LEMMA")[0];
    const std::string & gold = c.getTape("LEMMA").getRef(0);

    // The checks are ordered: the first transformation that yields the
    // reference decides which action is the correct one.
    const char * goldAction = nullptr;
    if (current == gold)
      goldAction = "NOTHING";
    else if (util::toLowerCase(current) == gold)
      goldAction = "TOLOWER b.0 LEMMA";
    else if (util::toUpperCase(current) == gold)
      goldAction = "TOUPPER b.0 LEMMA";

    if (goldAction == nullptr)
      return 1;
    return action == goldAction ? 0 : 1;
  })));
// "parser" oracle: cost-only oracle of a trainable classifier (asking it for
// an action aborts). The cost callback scores the transition actions SHIFT,
// WRITE, REDUCE, LEFT, RIGHT and EOS against the reference ID / GOV / LABEL
// tapes and the sequence-delimiter tape; 0 means the action is compatible with
// the reference. Tape accessors are called with offsets relative to the head.
str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle *)
{
},
[](Config &, Oracle *)
{
fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO);
exit(1);
return std::string("");
},
[](Config & c, Oracle *, const std::string & action)
{
auto & ids = c.getTape("ID");
auto & labels = c.getTape("LABEL");
auto & govs = c.getTape("GOV");
auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);
int head = c.getHead();
// With an empty stack, index 0 stands in for the stack head.
int stackHead = c.stackEmpty() ? 0 : c.stackTop();
// Head past the end of the delimiter tape: only "EOS s.0" is free.
if (head >= eos.size())
return action == "EOS s.0" ? 0 : 1;
if (ids.getRef(0).empty())
{
fprintf(stderr, "ERROR (%s) : ID.getRef(0) is empty. Aborting.\n", ERRINFO);
exit(1);
}
if (ids.getRef(stackHead-head).empty())
{
fprintf(stderr, "ERROR (%s) : ID.getRef(stackHead-head) is empty. Aborting.\n", ERRINFO);
exit(1);
}
// "Multiword": the reference ID splits on '-' into several parts.
// "Empty node": the reference ID splits on '.' into several parts.
bool headIsMultiword = util::split(ids.getRef(0), '-').size() > 1;
bool headIsEmptyNode = util::split(ids.getRef(0), '.').size() > 1;
// Governor index of the head: head plus the integer stored in the GOV
// reference, or -1 when that value cannot be parsed as an integer.
int headGov = -1;
try {headGov = head + std::stoi(govs.getRef(0));}
catch (std::exception &) {headGov = -1;}
bool stackHeadIsMultiword = util::split(ids.getRef(stackHead-head), '-').size() > 1;
bool stackHeadIsEmptyNode = util::split(ids.getRef(stackHead-head), '.').size() > 1;
// Same computation for the stack head.
int stackGov = -1;
try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));}
catch (std::exception &) {stackGov = -1;}
// NOTE(review): sentenceStart is computed below but never read afterwards in
// this lambda.
int sentenceStart = head-1 < 0 ? 0 : head-1;
int sentenceEnd = head;
int cost = 0;
while(sentenceStart >= 0 && eos.getRef(sentenceStart-head) != ProgramParameters::sequenceDelimiter)
sentenceStart--;
if (sentenceStart != 0)
sentenceStart++;
// sentenceEnd: first index from the head whose reference is the sequence
// delimiter, clamped to the last reference index.
while(sentenceEnd < eos.refSize() && eos.getRef(sentenceEnd-head) != ProgramParameters::sequenceDelimiter)
sentenceEnd++;
if (sentenceEnd == eos.refSize())
sentenceEnd--;
auto parts = util::split(action);
if (parts[0] == "SHIFT")
{
if (headIsMultiword || headIsEmptyNode)
return 0;
// +1 for every stack element that governs the head or is governed by it.
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
try
{
int sGov = s + std::stoi(govs.getRef(s-head));
if (sGov == head || headGov == s)
cost++;
}
catch (std::exception &) {continue;}
}
if (c.stackSize() && stackHead == head)
cost++;
// +1 when the stack head carries the sequence delimiter.
return eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
else if (parts[0] == "WRITE" && parts.size() == 4)
{
// parts[1] is "<object>.<index>"; only its object letter is inspected here.
auto object = util::split(parts[1], '.');
if (object[0] == "b")
{
// Any LABEL write is accepted when the reference label is "root".
if (parts[2] == "LABEL")
return (action == "WRITE b.0 LABEL " + c.getTape("LABEL").getRef(0) || c.getTape("LABEL").getRef(0) == "root") ? 0 : 1;
else if (parts[2] == "GOV")
return (action == ("WRITE b.0 GOV " + c.getTape("GOV").getRef(0))) ? 0 : 1;
}
else if (object[0] == "s")
{
int index = c.stackGetElem(-1);
if (parts[2] == "LABEL")
return (action == "WRITE s.-1 LABEL " + c.getTape("LABEL").getRef(index-head) || c.getTape("LABEL").getRef(index-head) == "root") ? 0 : 1;
else if (parts[2] == "GOV")
return (action == "WRITE s.-1 GOV " + c.getTape("GOV").getRef(index-head)) ? 0 : 1;
}
return 1;
}
else if (parts[0] == "REDUCE")
{
if (stackHeadIsMultiword || stackHeadIsEmptyNode)
return 0;
// +1 for every index from the head to the sentence end that governs the
// stack head or is governed by it.
for (int i = head; i <= sentenceEnd; i++)
{
try
{
int otherGov = i + std::stoi(govs.getRef(i-head));
if (otherGov == stackHead || stackGov == i)
cost++;
}
catch (std::exception &) {continue;}
}
if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter)
return cost;
return cost+1;
}
else if (parts[0] == "LEFT")
{
if (stackHeadIsMultiword || headIsMultiword || headIsEmptyNode || stackHeadIsEmptyNode)
return 1;
if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter)
cost++;
// Same dependency check as REDUCE, starting after the head.
for (int i = head+1; i <= sentenceEnd; i++)
{
try
{
int otherGov = i + std::stoi(govs.getRef(i-head));
if (otherGov == stackHead || stackGov == i)
cost++;
}
catch (std::exception &) {continue;}
}
// +1 when the head is not the reference governor of the stack head.
if (stackGov != head)
cost++;
if (parts.size() == 1)
return cost;
// Labels are compared on their part before the first ':'.
if (util::split(labels.getRef(stackHead-head), ':')[0] == util::split(parts[1], ':')[0])
return cost;
return cost+1;
}
else if (parts[0] == "RIGHT")
{
if (stackHeadIsMultiword || headIsMultiword || headIsEmptyNode || stackHeadIsEmptyNode)
return 1;
// +1 for every stack element other than the top that governs the head or is
// governed by it.
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == c.stackTop())
continue;
try
{
int otherGov = s + std::stoi(govs.getRef(s-head));
if (otherGov == head || headGov == s)
cost++;
}
catch (std::exception &) {continue;}
}
// +1 when the head's governor lies between the head and the sentence end.
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
cost++;
// +1 when the stack head is not the reference governor of the head.
if (headGov != stackHead)
cost++;
if (parts.size() == 1)
return cost;
if (util::split(labels.getRef(0), ':')[0] == util::split(parts[1], ':')[0])
return cost;
return cost+1;
}
else if (parts[0] == "EOS")
{
// NOTE(review): noGovs starts at -1 and is incremented at most once, so
// `noGovs > 0` is never true and this loop never changes cost - confirm
// whether that is intended.
for (int j = 1; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
int noGovs = -1;
if (govs.getHyp(s-head).empty())
noGovs++;
if (noGovs > 0)
cost += noGovs;
}
return eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter ? cost : cost+1;
}
// Any other action: cost is still 0 at this point.
return cost;
})));
// "parser_gold" oracle: reads the list of candidate actions from a file (an
// optional first line "Default : <action>", then one action per line) and
// answers with the first appliable action whose cost is 0. It shares the cost
// callback of the "parser" oracle registered above. If no zero-cost action
// exists, it prints the configuration and each action's cost, then aborts.
//
// Fix: the %[ conversions now carry a field width so that a line of 1024
// bytes or more can no longer overflow the buffer.
str2oracle.emplace("parser_gold", std::unique_ptr<Oracle>(new Oracle(
  [](Oracle * oracle)
  {
    File file(oracle->filename, "r");
    FILE * fd = file.getDescriptor();
    char b1[1024];
    if (fscanf(fd, "Default : %1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
    while (fscanf(fd, "%1023[^\n]\n", b1) == 1)
      oracle->data[b1] = "ok";
  },
  [](Config & c, Oracle * oracle) -> std::string
  {
    // Candidates are visited in the iteration order of oracle->data.
    for (auto & it : oracle->data)
    {
      Action action(it.first);
      if (!action.appliable(c))
        continue;
      if (oracle->getActionCost(c, it.first) == 0)
        return it.first;
    }

    // No zero-cost action: dump diagnostics for every candidate, then abort.
    fprintf(stderr, "ERROR (%s) : No zero cost action found by the oracle. Aborting.\n", ERRINFO);
    c.printForDebug(stderr);
    for (auto & it : oracle->data)
    {
      Action action(it.first);
      fprintf(stderr, "%s : ", action.name.c_str());
      explainCostOfAction(stderr, c, action.name);
      fprintf(stderr, "cost (%d)\n", oracle->getActionCost(c, it.first));
    }
    exit(1);
    return std::string("");
  },
  str2oracle["parser"]->getCost)));
}
void Oracle::explainCostOfAction(FILE * output, Config & c, const std::string & action)
{
auto parts = util::split(action);
if (parts[0] == "WRITE")
{
if (parts.size() != 4)
{
fprintf(stderr, "Wrong number of action arguments\n");
return;
}
auto object = util::split(parts[1], '.');
auto tape = parts[2];
auto label = parts[3];
std::string expected;
if (object[0] == "b")
{
int index = 0;
try {index = c.getHead() + std::stoi(object[1]);}
catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
expected = c.getTape(tape).getRef(index-c.getHead());
}
else if (object[0] == "s")
{
int stackIndex = 0;
try {stackIndex = std::stoi(object[1]);}
catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);}
int bufferIndex = c.stackGetElem(stackIndex) + c.getHead();
expected = c.getTape(tape).getRef(bufferIndex-c.getHead());
}
else
{
fprintf(stderr, "ERROR (%s) : wrong action name \'%s\'. Aborting.\n", ERRINFO, action.c_str());
exit(1);
}
fprintf(output, "Wrong write (%s) expected (%s)\n", label.c_str(), expected.c_str());
return;
}
else if (parts[0] == "IGNORECHAR")
{
if (!util::isUtf8Separator(c.rawInput.begin()+c.rawInputHeadIndex))
{
fprintf(stderr, "rawInputHead is pointing to non separator character <%c>(%d)\n", c.rawInput[c.rawInputHeadIndex], c.rawInput[c.rawInputHeadIndex]);
return;
}
else if (c.rawInputHeadIndex+1 > (int)c.rawInput.size())
{
fprintf(stderr, "rawInputHeadIndex=%d rawInputSize=%lu\n", c.rawInputHeadIndex, c.rawInput.size());
return;
}
fprintf(stderr, "cannot explain\n");
return;
}
else if (parts[0] == "ENDWORD")
{
if (c.getTape("FORM").getRef(0) != c.getTape("FORM").getHyp(0))
{
fprintf(stderr, "hyp <%s> and ref <%s> are different\n", c.getTape("FORM").getHyp(0).c_str(), c.getTape("FORM").getRef(0).c_str());
return;
}
fprintf(stderr, "cannot explain\n");
return;
}
else if (parts[0] == "ADDCHARTOWORD")
{
fprintf(stderr, "cannot explain\n");
return;
}
else if (parts[0] == "SPLITWORD")
{
fprintf(stderr, "cannot explain\n");
return;
}
auto & labels = c.getTape("LABEL");
auto & govs = c.getTape("GOV");
auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape);
int head = c.getHead();
int stackHead = c.stackEmpty() ? 0 : c.stackTop();
int stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));
int headGov = head + std::stoi(govs.getRef(0));
int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1;
int sentenceEnd = c.getHead();
while(sentenceStart >= 0 && eos.getRef(sentenceStart-head) != ProgramParameters::sequenceDelimiter)
sentenceStart--;
if (sentenceStart != 0)
sentenceStart++;
while(sentenceEnd < eos.refSize() && eos.getRef(sentenceEnd-head) != ProgramParameters::sequenceDelimiter)
sentenceEnd++;
if (sentenceEnd == eos.refSize())
sentenceEnd--;
if (parts[0] == "SHIFT")
{
for (int i = sentenceStart; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.getRef(i-head));
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == i)
{
if (otherGov == head)
{
fprintf(output, "Word on stack %d(%s)\'s governor is the current getHead()\n", s, c.getTape("FORM").getRef(s-head).c_str());
return;
}
else if (headGov == s)
{
fprintf(output, "The current head\'s governor is on the stack %d(%s)\n", s, c.getTape("FORM").getRef(s-head).c_str());
return;
}
}
}
}
if (eos.getRef(0) != ProgramParameters::sequenceDelimiter)
{
fprintf(output, "Zero cost\n");
return;
}
else
{
fprintf(output, "The top of the stack is end of sentence\n");
}
}
else if (parts[0] == "REDUCE")
{
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.getRef(i-head));
if (otherGov == stackHead)
{
fprintf(output, "Stack getHead() is the governor of %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
return;
}
}
if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter)
{
fprintf(output, "Zero cost\n");
return;
}
else
{
fprintf(output, "The top of the stack is end of sentence\n");
}
}
else if (parts[0] == "LEFT")
{
if (parts.size() == 2 && stackGov == head && labels.getRef(stackHead-head) == parts[1])
{
fprintf(output, "Zero cost\n");
return;
}
if (parts.size() == 1 && stackGov == head)
{
fprintf(output, "Zero cost\n");
return;
}
if (labels.getRef(stackHead-head) != parts[1])
{
fprintf(output, "Stack's head label %s mismatch with action label %s\n", labels.getRef(stackHead-head).c_str(), parts[1].c_str());
return;
}
for (int i = head; i <= sentenceEnd; i++)
{
int otherGov = i + std::stoi(govs.getRef(i-head));
if (otherGov == stackHead)
{
fprintf(output, "Word %d(%s)\'s governor is the stack's head\n", i, c.getTape("FORM").getRef(i-head).c_str());
return;
}
else if (stackGov == i)
{
fprintf(output, "Stack head's governor is the word %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
return;
}
}
if (parts.size() == 1)
{
fprintf(output, "ERROR (%s) : Unexpected situation\n", ERRINFO);
}
fprintf(output, "Unable to explain action\n");
return;
}
else if (parts[0] == "RIGHT")
{
for (int j = 0; j < c.stackSize(); j++)
{
auto s = c.stackGetElem(j);
if (s == c.stackTop())
continue;
int otherGov = s + std::stoi(govs.getRef(s-head));
if (otherGov == head)
{
fprintf(output, "The governor of %d(%s) in the stack, is the current head\n", s, c.getTape("FORM").getRef(s-head).c_str());
return;
}
else if (headGov == s)
{
fprintf(output, "The current head's governor is the stack element %d(%s)\n", s, c.getTape("FORM").getRef(s-head).c_str());
return;
}
}
for (int i = head; i <= sentenceEnd; i++)
if (headGov == i)
{
fprintf(output, "The current head's governor is the future word %d(%s)\n", i, c.getTape("FORM").getRef(i-head).c_str());
return;
}
if (parts.size() == 1)
{
fprintf(output, "Zero cost\n");
return;
}
if (labels.getRef(0) == parts[1])
{
fprintf(output, "Zero cost\n");
return;
}
else
{
fprintf(output, "Current head's label %s mismatch action label %s\n", labels.getRef(0).c_str(), parts[1].c_str());
return;
}
}
else if (parts[0] == ProgramParameters::sequenceDelimiterTape)
{
if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter)
{
fprintf(output, "Zero cost\n");
return;
}
else
{
fprintf(output, "The head of the stack is not end of sentence\n");
return;
}
}
fprintf(output, "Unknown reason\n");
}