diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index fec7ae7c675c73b4a574ceaf5a08195ab73d355d..65e07b08babbf3982d880406b75a4b5ceb7eeee4 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -64,13 +64,13 @@ void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff) { int totalSize = ProgramParameters::tapeSize; int steps = config.getHead(); - if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)) + if (ProgramParameters::rawInput) { - if (ProgramParameters::rawInput) - fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*config.rawInputHeadIndex/config.rawInput.size(), int2humanStr((int)currentSpeed).c_str()); - else - fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str()); + totalSize = config.rawInput.size(); + steps = config.rawInputHeadIndex; } + if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)) + fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*config.rawInputHeadIndex/config.rawInput.size(), int2humanStr((int)currentSpeed).c_str()); } } @@ -95,6 +95,8 @@ std::pair<float,std::string> getClassifierAction(Config & config, Classifier::We } std::string & predictedAction = weightedActions[0].second.second; + if (predictedAction.empty()) + throw EndOfDecode(); float proba = weightedActions[0].second.first; Action * action = classifier->getAction(predictedAction); @@ -183,12 +185,13 @@ void computeAndRecordEntropy(Config & config, Classifier::WeightedActions & weig void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & actionName, Config & config) { - if (ProgramParameters::debug) - fprintf(stderr, "Applying action=<%s>\n", actionName.c_str()); + Action * action = tm.getCurrentClassifier()->getAction(actionName); TransitionMachine::Transition * transition = tm.getTransition(actionName); action->setInfos(tm.getCurrentClassifier()->name); config.addToActionsHistory(tm.getCurrentClassifier()->name, actionName, 0); + if (ProgramParameters::debug) + fprintf(stderr, "Applying action=<%s>\n", action->name.c_str()); action->apply(config); tm.takeTransition(transition); } @@ -196,6 +199,7 @@ void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & ac void Decoder::decode() { config.reset(); + config.fillTapesWithInput(); if (ProgramParameters::beamSize > 1) decodeBeam(); @@ -232,7 +236,7 @@ void Decoder::decodeNoBeam() std::pair<float,std::string> predictedAction; try {predictedAction = getClassifierAction(config, weightedActions, tm.getCurrentClassifier(), 0);} - catch(EndOfDecode &) {continue;} + catch(EndOfDecode &) {break;} catch(NoMoreActions &) {continue;}; checkAndRecordError(config, tm.getCurrentClassifier(), weightedActions, predictedAction.second, errors); @@ -252,7 +256,7 @@ void Decoder::decodeNoBeam() if (ProgramParameters::errorAnalysis) errors.printStats(); - config.printTheRest(); + config.printTheRest(false); if (ProgramParameters::interactive) fprintf(stderr, " \n"); @@ -485,7 +489,7 @@ void Decoder::decodeBeam() for (auto node : beam) { node->config.setOutputFile(outputFile); - node->config.printTheRest(); + node->config.printTheRest(false); } if (ProgramParameters::interactive) diff --git a/maca_common/include/ProgramOutput.hpp b/maca_common/include/ProgramOutput.hpp index dca9bfeddfc6c61941ba6cbea9e829a87f2ff2de..e531d0397a3d3f092913702f23725ed82df4c24a 100644 --- a/maca_common/include/ProgramOutput.hpp +++ b/maca_common/include/ProgramOutput.hpp @@ -23,7 +23,7 @@ struct ProgramOutput public : void print(FILE * output); - void addLine(const std::vector< std::pair<std::string, float> > & line, unsigned int index); + void addLine(FILE * output, const std::vector< std::pair<std::string, float> > & line, unsigned int index); }; #endif diff --git a/maca_common/include/util.hpp b/maca_common/include/util.hpp index c9033fccf22b843fd2365007557cb974b2dc884f..245a082cd22d071acd8a03037434648552d90938 100644 --- a/maca_common/include/util.hpp +++ b/maca_common/include/util.hpp @@ -221,6 +221,7 @@ int getEndIndexOfNthSymbol(const std::string & s, int n); int getEndIndexOfNthSymbolFrom(const std::string::iterator & s, const std::string::iterator & end, int n); unsigned int getNbSymbols(const std::string & s); std::string shrinkString(const std::string & base, int maxSize, const std::string token); +std::string strip(const std::string & s); /// @brief Macro giving informations about an error. #define ERRINFO (getFilenameFromPath(std::string(__FILE__))+ ":l." + std::to_string(__LINE__)).c_str() diff --git a/maca_common/src/ProgramOutput.cpp b/maca_common/src/ProgramOutput.cpp index e9bde67e1d68c1afa92e4b7443d52635b37fed0a..bdab98c874eb63b802a42a7d49ac6462814a0466 100644 --- a/maca_common/src/ProgramOutput.cpp +++ b/maca_common/src/ProgramOutput.cpp @@ -13,12 +13,12 @@ void ProgramOutput::print(FILE * output) fprintf(output, "%s%s%s", ProgramParameters::printOutputEntropy ? ("<"+float2str(line[i].second,"%f")+">").c_str() : "", line[i].first.c_str(), i == line.size()-1 ? "\n" : "\t"); } -void ProgramOutput::addLine(const std::vector< std::pair<std::string, float> > & line, unsigned int index) +void ProgramOutput::addLine(FILE * output, const std::vector< std::pair<std::string, float> > & line, unsigned int index) { if (!ProgramParameters::delayedOutput) { for (unsigned int i = 0; i < line.size(); i++) - fprintf(stdout, "%s%s", line[i].first.c_str(), i == line.size()-1 ? "\n" : "\t"); + fprintf(output, "%s%s", line[i].first.c_str(), i == line.size()-1 ? "\n" : "\t"); return; } diff --git a/maca_common/src/util.cpp b/maca_common/src/util.cpp index 6dc2c199881971ce0eece40e4c4b115a2720d21b..20dff0dc39fa2855e95b6da275a048af76356409 100644 --- a/maca_common/src/util.cpp +++ b/maca_common/src/util.cpp @@ -568,3 +568,15 @@ std::string shrinkString(const std::string & base, int maxSize, const std::strin return result; } +std::string strip(const std::string & s) +{ + std::string res; + unsigned int i = 0; + while (i < s.size() && isSeparator(s[i])) + i++; + while (i < s.size() && !isSeparator(s[i])) + res.push_back(s[i++]); + + return res; +} + diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index 867712773ca75e81f71b812577351783d3c31607..2058d5446fe53e9fe952c02ba64e9ebabcd909b0 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -155,18 +155,64 @@ float TrainInfos::computeScoreOnTapes(Config & c, std::vector<std::string> tapes void TrainInfos::computeTrainScores(Config & c) { + std::string name; + { + File tmpOutTrain("bin/"+ProgramParameters::expName+"/tmpOutTrain.txt", "w"); + name = tmpOutTrain.getName(); + c.setOutputFile(tmpOutTrain.getDescriptor()); + c.printTheRest(false); + c.setOutputFile(nullptr); + c.setLastIndexPrinted(-1); + } + + std::string name2; + { + File tmpOutTrain("bin/"+ProgramParameters::expName+"/tmpOutTrainRef.txt", "w"); + name2 = tmpOutTrain.getName(); + c.setOutputFile(tmpOutTrain.getDescriptor()); + c.printTheRest(true); + c.setOutputFile(nullptr); + c.setLastIndexPrinted(-1); + } + + { + FILE * trainInGoodConllFormat = popen(("../tools/conlluAddMissingColumns.py " + name + " data/conllu.mcd > bin/" + ProgramParameters::expName + "/tmpOutTrain.conllu").c_str(), "w"); + pclose(trainInGoodConllFormat); + } + { + FILE * trainInGoodConllFormat = popen(("../tools/conlluAddMissingColumns.py " + name2 + " data/conllu.mcd > bin/" + ProgramParameters::expName + "/tmpOutTrainRef.conllu").c_str(), "w"); + pclose(trainInGoodConllFormat); + } + + std::map<std::string, std::string> scoresStr; + std::map<std::string, float> scoresFloat; + { + FILE * evalFromUD = popen(("../scripts/conll18_ud_eval.py " + std::string(" bin/") + ProgramParameters::expName + "/tmpOutTrainRef.conllu " + std::string(" bin/") + ProgramParameters::expName + "/tmpOutTrain.conllu -v").c_str(), "r"); + char buffer[10000]; + while (fscanf(evalFromUD, "%[^\n]\n", buffer) == 1) + { + auto splited = split(buffer, '|'); + if (splited.size() > 2) + scoresStr[strip(splited[0])] = strip(splited[3]); + } + pclose(evalFromUD); + } + + for (auto & it : scoresStr) + try {scoresFloat[it.first] = std::stof(it.second);} catch(std::exception &){} + for (auto & it : topologyPrinted) { if (it.first == "Parser") - addTrainScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); - else if (it.first == "Tagger") - addTrainScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead())); + addTrainScore(it.first, scoresFloat["MLAS"]); else if (it.first == "Tokenizer") - addTrainScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead())); + addTrainScore(it.first, scoresFloat["Tokens"]); + else if (it.first == "Tagger") + addTrainScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") - addTrainScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead())); + addTrainScore(it.first, scoresFloat["UFeats"]); else if (it.first == "Lemmatizer_Rules") - addTrainScore(it.first, computeScoreOnTapes(c, {"LEMMA"}, 0, c.getHead())); + addTrainScore(it.first, scoresFloat["Lemmas"]); else if (split(it.first, '_')[0] == "Error") addTrainScore(it.first, 100.0); else @@ -179,20 +225,50 @@ void TrainInfos::computeTrainScores(Config & c) void TrainInfos::computeDevScores(Config & c) { + std::string name; + { + File tmpOutDev("bin/"+ProgramParameters::expName+"/tmpOutDev.txt", "w"); + name = tmpOutDev.getName(); + c.setOutputFile(tmpOutDev.getDescriptor()); + c.printTheRest(false); + c.setOutputFile(nullptr); + c.setLastIndexPrinted(-1); + } + + { + FILE * devInGoodConllFormat = popen(("../tools/conlluAddMissingColumns.py " + name + " data/conllu.mcd > bin/" + ProgramParameters::expName + "/tmpOutDev.conllu").c_str(), "w"); + pclose(devInGoodConllFormat); + } + + std::map<std::string, std::string> scoresStr; + std::map<std::string, float> scoresFloat; + { + FILE * evalFromUD = popen(("../scripts/conll18_ud_eval.py " + ProgramParameters::devFilename + " bin/" + ProgramParameters::expName + "/tmpOutDev.conllu -v").c_str(), "r"); + char buffer[10000]; + while (fscanf(evalFromUD, "%[^\n]\n", buffer) == 1) + { + auto splited = split(buffer, '|'); + if (splited.size() > 2) + scoresStr[strip(splited[0])] = strip(splited[3]); + } + pclose(evalFromUD); + } + + for (auto & it : scoresStr) + try {scoresFloat[it.first] = std::stof(it.second);} catch(std::exception &){} + for (auto & it : topologyPrinted) { if (it.first == "Parser") - addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); - else if (it.first == "Parser") - addDevScore(it.first, computeScoreOnTapes(c, {"GOV", "LABEL"}, 0, c.getHead())); + addDevScore(it.first, scoresFloat["MLAS"]); else if (it.first == "Tokenizer") - addDevScore(it.first, computeScoreOnTapes(c, {"FORM"}, 0, c.getHead())); + addDevScore(it.first, scoresFloat["Tokens"]); else if (it.first == "Tagger") - addDevScore(it.first, computeScoreOnTapes(c, {"POS"}, 0, c.getHead())); + addDevScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") - addDevScore(it.first, computeScoreOnTapes(c, {"MORPHO"}, 0, c.getHead())); + addDevScore(it.first, scoresFloat["UFeats"]); else if (it.first == "Lemmatizer_Rules") - addDevScore(it.first, computeScoreOnTapes(c, {"LEMMA"}, 0, c.getHead())); + addDevScore(it.first, scoresFloat["Lemmas"]); else if (split(it.first, '_')[0] == "Error") addDevScore(it.first, 100.0); else diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index dd9fb084a4090908a00d2b869a944c94f9447ec3..9aafaf4e94b83783da3f9bcc0acc2e5790733e87 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -49,6 +49,7 @@ void Trainer::computeScoreOnDev() tm.reset(); devConfig->reset(); + devConfig->fillTapesWithInput(); if (ProgramParameters::debug) fprintf(stderr, "Computing score on dev set\n"); @@ -62,7 +63,7 @@ void Trainer::computeScoreOnDev() auto pastTime = std::chrono::high_resolution_clock::now(); std::vector<float> entropies; - while (!devConfig->isFinal()) + while (true) { setDebugValue(); devConfig->setCurrentStateName(tm.getCurrentClassifier()->name); @@ -71,8 +72,11 @@ void Trainer::computeScoreOnDev() if(!tm.getCurrentClassifier()->needsTrain()) { - int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig); - std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); + std::string neededActionName = tm.getCurrentClassifier()->getOracleAction(*devConfig); + + if (neededActionName.empty()) + break; + Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); action->setInfos(tm.getCurrentClassifier()->name); @@ -87,6 +91,11 @@ void Trainer::computeScoreOnDev() { int totalSize = ProgramParameters::devTapeSize; int steps = devConfig->getHead(); + if (devConfig->rawInputHeadIndex > 0) + { + totalSize = devConfig->rawInput.size(); + steps = devConfig->rawInputHeadIndex; + } if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)) { fprintf(stderr, " \r"); @@ -112,7 +121,11 @@ void Trainer::computeScoreOnDev() } if (pAction.empty()) + { + if (ProgramParameters::debug) + fprintf(stderr, "No action predicted\n"); break; + } if (ProgramParameters::devLoss) { @@ -197,13 +210,15 @@ void Trainer::resetAndShuffle() trainConfig.reset(); if(ProgramParameters::shuffleExamples) - trainConfig.shuffle(ProgramParameters::sequenceDelimiterTape, ProgramParameters::sequenceDelimiter); + trainConfig.shuffle(); + + trainConfig.fillTapesWithInput(); } void Trainer::doStepNoTrain() { - int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig); - std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); + std::string neededActionName = tm.getCurrentClassifier()->getOracleAction(trainConfig); + if (ProgramParameters::debug) { fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str()); @@ -211,6 +226,9 @@ void Trainer::doStepNoTrain() fprintf(stderr, "action=<%s>\n", neededActionName.c_str()); } + if (neededActionName.empty()) + throw EndOfIteration(); + Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); action->setInfos(tm.getCurrentClassifier()->name); @@ -233,6 +251,11 @@ void Trainer::doStepTrain() { int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize; int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps; + if (trainConfig.rawInputHeadIndex > 0) + { + totalSize = trainConfig.rawInput.size(); + steps = trainConfig.rawInputHeadIndex; + } if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff) { fprintf(stderr, " \r"); @@ -270,7 +293,14 @@ void Trainer::doStepTrain() } if (oAction.empty()) + { oAction = tm.getCurrentClassifier()->getDefaultAction(); + if(!tm.getCurrentClassifier()->getAction(oAction)->appliable(trainConfig)) + oAction.clear(); + } + + if (oAction.empty()) + oAction = pAction; if (oAction.empty()) { @@ -301,7 +331,7 @@ void Trainer::doStepTrain() fprintf(stdout, "%s\t%s\t%s\n", tm.getCurrentClassifier()->getFeatureModel()->filename.c_str(), oAction.c_str(), features.c_str()); } - if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) + if (tm.getCurrentClassifier()->isDynamic() && TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) { actionName = pAction; TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true; @@ -510,7 +540,7 @@ void Trainer::train() while (TI.getEpoch() <= ProgramParameters::nbIter) { resetAndShuffle(); - while (!trainConfig.isFinal()) + while (true) { setDebugValue(); trainConfig.setCurrentStateName(tm.getCurrentClassifier()->name); @@ -518,7 +548,8 @@ void Trainer::train() tm.getCurrentClassifier()->initClassifier(trainConfig); if(!tm.getCurrentClassifier()->needsTrain()) - doStepNoTrain(); + try {doStepNoTrain();} + catch (EndOfIteration &) {break;} else try {doStepTrain();} catch (EndOfIteration &) {break;} @@ -544,6 +575,8 @@ void Trainer::train() void Trainer::printScoresAndSave(FILE * output) { + trainConfig.transformSymbol("", "_"); + devConfig->transformSymbol("", "_"); TI.computeTrainScores(trainConfig); computeScoreOnDev(); TI.computeMustSaves(); diff --git a/transition_machine/include/BD.hpp b/transition_machine/include/BD.hpp index 7e7aea8931d640d9375037255c12d52f54b9e094..bdffaaf26d4bce8da9dbef6bd48f215957973678 100644 --- a/transition_machine/include/BD.hpp +++ b/transition_machine/include/BD.hpp @@ -35,6 +35,8 @@ class BD /// /// If this Line's values don't need to be predicted. bool isKnown; + /// @brief What column will it be in the output. + int outputIndex; /// @brief Create a new Line. /// @@ -44,7 +46,7 @@ class BD /// @param inputColumn The column of the MCD this Line corresponds to. /// @param mustPrint Whether or not this Line is part of the expected output of the program. /// @param isKnown Whether or not the entirety of this Line is already known. - Line(int num, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown); + Line(int num, int outputIndex, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown); }; private : @@ -111,6 +113,12 @@ class BD /// /// @return Whether or not this Line is part of the output. bool mustPrintLine(int index); + /// @brief Return the column index of line index in the outut. + /// + /// @param index The index of the Line. + /// + /// @return The column index of this line in the output. + int getOutputIndexOfLine(int index); /// @brief Get the name of a Line from its index. /// /// @param line The index of the Line. diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index fb134400a6145214f381f70a44b811373f96f46a..d85935bfcc6089725e95d8af7a87e8957921395f 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -62,6 +62,8 @@ class Classifier /// the correct Action to associate with a training example.\n /// For Classifier of type Information, the Oracle is used in train mode and decode mode too, it is simply a deterministic function that gives the correct Action given a Configuration. Oracle * oracle; + /// @brief Is this classifier subject to a dynamic oracle. + bool dynamic; private : @@ -238,6 +240,8 @@ class Classifier FeatureModel * getFeatureModel(); /// @brief Prepare Classifier for next iteration. void endOfIteration(); + /// @brief Is this classifier subject to a dynamic oracle. + bool isDynamic(); }; #endif diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index f988f2bf638f4edbc1a353fd69861ca3374bd88b..d89f93456c5d9a7b0db1a7edd2dad87bca8c2f97 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -73,6 +73,13 @@ class Config /// @param relativeIndex The index of the cell relatively to the head. /// @param elem The new content of the cell. void setHyp(int relativeIndex, const std::string & elem); + /// @brief Set the value of a cell of the ref. + /// + /// @param relativeIndex The index of the cell relatively to the head. + /// @param elem The new content of the cell. + void setRef(int relativeIndex, const std::string & elem); + void set(int relativeIndex, const std::string & elem); + int getHead(); /// @brief Return true if the head of this tape is on the last cell. /// /// @return True if the head of this tape is on the last cell. @@ -189,6 +196,8 @@ class Config int rawInputHeadIndex; /// @brief Index of current word in the sentence, as in conll format. int currentWordIndex; + /// @brief The conll input as it was read. + std::vector< std::vector<std::string> > inputContent; public : @@ -221,6 +230,7 @@ class Config Tape & getTapeByInputCol(int col); /// @brief Read a part of a formated input file (mcf) and use it to fill the tapes. void readInput(); + void fillTapesWithInput(); /// @brief Print the Config for debug purposes. /// /// @param output Where to print. @@ -230,7 +240,8 @@ class Config /// @param output Where to print. /// @param dataIndex Index of line to print. /// @param realIndex Index of line to print. - void printAsOutput(FILE * output, int dataIndex, int realIndex); + /// @param forceRef True to force the output to be the ref tape. + void printAsOutput(FILE * output, int dataIndex, int realIndex, bool forceRef); /// @brief Print the Config without information loss. /// /// @param output Where to print. @@ -274,13 +285,8 @@ class Config /// /// @return The history of entropies of the current state in the TransitionMachine. LimitedStack<float> & getCurrentStateEntropyHistory(); - /// @brief Shuffle the segments of the Config. - /// - /// For instance if you call shuffle("EOS", "1");\n - /// Sentences will be preserved, but their order will be shuffled. - /// @param delimiterTape The tape containing the delimiters of segments. - /// @param delimiter The delimiters of segments. - void shuffle(const std::string & delimiterTape, const std::string & delimiter); + /// @brief Shuffle the Config per sequences. + void shuffle(); /// @brief Get element from the stack at depth index. /// /// @param index The depth of the requested element. @@ -352,12 +358,10 @@ class Config /// /// @return True if the head is at the end of the tapes. bool endOfTapes() const; - /// @brief Update rawInput according to the tape TEXT. - void updateRawInput(); /// @brief Set the output file. void setOutputFile(FILE * outputFile); /// @brief Print the cells that have not been printed. - void printTheRest(); + void printTheRest(bool forceRef); void setEntropy(float entropy); float getEntropy() const; void addToEntropy(float entropy); @@ -367,6 +371,12 @@ class Config void printColumnInfos(unsigned int index); void addToActionsHistory(std::string & state, const std::string & action, int cost); std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & state); + void transformSymbol(const std::string & from, const std::string & to); + void setLastIndexPrinted(int lastIndexPrinted); + /// @brief Transform the tape GOV from relative indexes to UD format. + void setGovsAsUD(bool ref); + /// @brief Update the IDs in the last predicted sequence. + void updateIdsInSequence(); }; #endif diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 26d167ed843f308e0c44da1de875abfd5ba658bd..d68d017de9de645b65d6129c5660044cb5742099 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -163,7 +163,7 @@ Action::BasicAction ActionBank::pushHead() auto undo = [](Config & c, Action::BasicAction &) {c.stackPop();}; auto appliable = [](Config & c, Action::BasicAction &) - {return !(c.stackSize() >= ProgramParameters::maxStackSize) && (!c.endOfTapes());}; + {return !(c.stackSize() >= ProgramParameters::maxStackSize || (!c.stackEmpty() && c.stackTop() == c.getHead()));}; Action::BasicAction basicAction = {Action::BasicAction::Type::Push, "", apply, undo, appliable}; @@ -184,7 +184,7 @@ Action::BasicAction ActionBank::stackPop(bool checkGov) if (!checkGov) return true; - return !c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty(); + return split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1 || (!c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty() && c.stackTop() != c.getHead()); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Pop, "", apply, undo, appliable}; @@ -321,6 +321,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } else if(std::string(b1) == "ADDCHARTOWORD") { + sequence.emplace_back(increaseTapesIfNeeded(0)); + auto apply = [](Config & c, Action::BasicAction &) {addCharToBuffer(c, "FORM", 0);}; auto undo = [](Config & c, Action::BasicAction &) @@ -426,6 +428,12 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na return false; int b0 = c.getHead(); int s0 = c.stackTop(); + + if (split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + return simpleBufferWriteAppliable(c, "GOV", s0-b0); }; Action::BasicAction basicAction = @@ -477,7 +485,13 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na }; auto appliable = [](Config & c, Action::BasicAction &) { - return !c.stackEmpty() && !c.endOfTapes() && simpleBufferWriteAppliable(c, "GOV", 0); + if (c.stackEmpty()) + return false; + if (split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + return simpleBufferWriteAppliable(c, "GOV", 0); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, "", apply, undo, appliable}; @@ -508,6 +522,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } else if(std::string(b1) == "EOS") { + // Puting the EOS tag on the last element of the sentence. auto apply0 = [b2](Config & c, Action::BasicAction &) { int b0 = c.getHead(); @@ -529,22 +544,27 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(basicAction0); + // Chosing root of the sentence and attaching floating words to it. auto apply = [](Config & c, Action::BasicAction & ba) { ba.data = ""; auto & govs = c.getTape("GOV"); + auto & ids = c.getTape("ID"); int b0 = c.getHead(); int rootIndex = -1; for (int i = c.stackSize()-1; i >= 0; i--) { auto s = c.stackGetElem(i); + if (split(ids.getRef(s-b0), '-').size() > 1) + continue; if (govs.getHyp(s-b0).empty() || govs.getHyp(s-b0) == "0") { if (rootIndex == -1) rootIndex = s; else { - simpleBufferWrite(c, "GOV", std::to_string(rootIndex - s), s-b0); + simpleBufferWrite(c, "GOV", std::to_string(rootIndex-s), s-b0); + simpleBufferWrite(c, "LABEL", "_", s-b0); ba.data += "+"+std::to_string(s-b0); } } @@ -563,26 +583,10 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } simpleBufferWrite(c, "GOV", "0", rootIndex-b0); + simpleBufferWrite(c, "LABEL", "root", rootIndex-b0); // Delete the arcs from the previous sentence to the new sentence - - auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape); - - for (int i = b0; i >= 0; i--) - { - if (eos[i-b0] == ProgramParameters::sequenceDelimiter) - break; - - try - { - int govIndex = i + std::stoi(govs[i-b0]); - if (govIndex <= c.stackGetElem(0)) - { - simpleBufferWrite(c, "GOV", std::to_string(rootIndex - i), i-b0); - simpleBufferWrite(c, "LABEL", "_", i-b0); - } - } catch (std::exception &) {} - } + // TODO }; auto undo = [](Config & c, Action::BasicAction & ba) { @@ -594,13 +598,17 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na if (govs.getHyp(s-b0) == "0") { simpleBufferWrite(c, "GOV", "", s-b0); + simpleBufferWrite(c, "LABEL", "", s-b0); break; } } auto deps = split(ba.data, '+'); for (auto s : deps) if (!s.empty()) + { simpleBufferWrite(c, "GOV", "", std::stoi(s)); + simpleBufferWrite(c, "LABEL", "", std::stoi(s)); + } ba.data.clear(); }; auto appliable = [](Config & c, Action::BasicAction &) @@ -612,61 +620,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(basicAction); - auto apply2 = [b2](Config & c, Action::BasicAction & ba) - { - ba.data = ""; - auto & labels = c.getTape("LABEL"); - int b0 = c.getHead(); - int rootIndex = -1; - - for (int i = c.stackSize()-1; i >= 0; i--) - { - auto s = c.stackGetElem(i); - if (labels.getHyp(s-b0).empty()) - { - if (rootIndex == -1) - { - rootIndex = 0; - simpleBufferWrite(c, "LABEL", "root", s-b0); - } - else - { - simpleBufferWrite(c, "LABEL", "_", s-b0); - ba.data += "+"+std::to_string(s-b0); - } - } - } - }; - auto undo2 = [](Config & c, Action::BasicAction & ba) - { - auto & labels = c.getTape("LABEL"); - int b0 = c.getHead(); - - for (int i = c.stackSize()-1; i >= 0; i--) - { - auto s = c.stackGetElem(i); - if (labels.getHyp(s-b0) == "root") - { - simpleBufferWrite(c, "LABEL", "", s-b0); - break; - } - } - auto deps = split(ba.data, '+'); - for (auto & dep : deps) - if (!dep.empty()) - simpleBufferWrite(c, "LABEL", "", std::stoi(dep)); - ba.data.clear(); - }; - auto appliable2 = [](Config & c, Action::BasicAction &) - { - return !c.stackEmpty(); - }; - Action::BasicAction basicAction2 = - {Action::BasicAction::Type::Write, "", apply2, undo2, appliable2}; - - sequence.emplace_back(basicAction2); - - auto apply4 = [b2](Config & c, Action::BasicAction & ba) + // Empty the stack. + auto apply4 = [](Config & c, Action::BasicAction & ba) { ba.data = ""; for (int i = c.stackSize()-1; i >= 0; i--) @@ -694,6 +649,23 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na {Action::BasicAction::Type::Pop, "", apply4, undo4, appliable4}; sequence.emplace_back(basicAction4); + + // Update the IDs of the words in the new sentence + auto apply5 = [](Config & c, Action::BasicAction &) + { + c.updateIdsInSequence(); + }; + auto undo5 = [](Config &, Action::BasicAction &) + { + }; + auto appliable5 = [](Config &, Action::BasicAction &) + { + return true; + }; + Action::BasicAction basicAction5 = + {Action::BasicAction::Type::Write, "", apply5, undo5, appliable5}; + + sequence.emplace_back(basicAction5); } else if(std::string(b1) == "BACK") { @@ -831,10 +803,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, int index = config.getHead() + relativeIndex; - if (config.endOfTapes()) - return true; - - return !(index < 0) && index < tape.size(); + return !(index < 0) && index < tape.size() && tape.getHyp(relativeIndex).empty(); } bool ActionBank::isRuleAppliable(Config & config, diff --git a/transition_machine/src/BD.cpp b/transition_machine/src/BD.cpp index 957c01a23a5a364b4032e6778991ade3cc859b0c..42f337fe22f5319eb1a8c5dc5b393b6ba5539bad 100644 --- a/transition_machine/src/BD.cpp +++ b/transition_machine/src/BD.cpp @@ -2,9 +2,10 @@ #include "File.hpp" #include "util.hpp" -BD::Line::Line(int num, std::string name, std::string dictName, +BD::Line::Line(int num, int outputIndex, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown) { + this->outputIndex = outputIndex; this->dict = dictName; this->num = num; this->name = name; @@ -40,17 +41,6 @@ BD::BD(const std::string & BDfilename, const std::string & MCDfilename) exit(1); } - if(mcdCol2Str.find(col) != mcdCol2Str.end()) - { - fprintf(stderr, "ERROR (%s) : MCD column \'%d\' already exists. Aborting.\n", ERRINFO, col); - exit(1); - } - if(mcdStr2Col.find(name) != mcdStr2Col.end()) - { - fprintf(stderr, "ERROR (%s) : MCD column \'%s\' already exists. Aborting.\n", ERRINFO, name); - exit(1); - } - mcdCol2Str[col] = name; mcdStr2Col[name] = col; } @@ -63,7 +53,9 @@ BD::BD(const std::string & BDfilename, const std::string & MCDfilename) if(buffer[0] == '#') continue; - if(sscanf(buffer, "%s %s %s %s %d", name, refHyp, dict, policy, &mustPrint) != 5) + int outputIndex; + + if(sscanf(buffer, "%d %s %s %s %s %d", &outputIndex, name, refHyp, dict, policy, &mustPrint) != 6) { fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid BD line. Aborting.\n", ERRINFO, buffer); exit(1); @@ -79,7 +71,7 @@ BD::BD(const std::string & BDfilename, const std::string & MCDfilename) int inputColumn = mcdStr2Col.find(name) == mcdStr2Col.end() ? -1 : mcdStr2Col[name]; - lines.emplace_back(new Line(lines.size(), name, dict, inputColumn, mustPrint == 1, known)); + lines.emplace_back(new Line(lines.size(), outputIndex, name, dict, inputColumn, mustPrint == 1, known)); Line * line = lines.back().get(); num2line.emplace(line->num, line); name2line.emplace(line->name, line); @@ -191,3 +183,8 @@ bool BD::lineIsKnown(int line) return lines[line]->isKnown; } +int BD::getOutputIndexOfLine(int index) +{ + return lines[index]->outputIndex; +} + diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index f68f1cc7e60f9c68cfd0acb217f8e2d2574297a6..413aac277cfca6e72e217a2622da9e28bc0d39ce 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -80,6 +80,13 @@ Classifier::Classifier(const std::string & filename, bool trainMode) int batchsizeRead = 0; if(fscanf(fd, "Batchsize : %d\n", &batchsizeRead) == 1) batchSize = batchsizeRead; + + dynamic = false; + if(fscanf(fd, "Dynamic : %[^\n]\n", buffer) == 1) + { + if (!strcmp(buffer, "yes")) + dynamic = true; + } } Classifier::Type Classifier::str2type(const std::string & s) @@ -278,15 +285,25 @@ std::vector<std::string> Classifier::getZeroCostActions(Config & config) if (a.appliable(config) && oracle->getActionCost(config, a.name) == 0) result.emplace_back(a.name); + if (ProgramParameters::debug) + { + fprintf(stderr, "Zero cost actions : "); + for (auto & s : result) + fprintf(stderr, "<%s>", s.c_str()); + fprintf(stderr, "\n"); + } + if (result.empty() && as->hasDefaultAction) - result.emplace_back(as->getDefaultAction()->name); + if (as->getDefaultAction()->appliable(config)) + result.emplace_back(as->getDefaultAction()->name); return result; } std::string Classifier::getDefaultAction() const { - return as->getDefaultAction()->name; + if (as->hasDefaultAction) + return as->getDefaultAction()->name; return std::string(); } @@ -419,3 +436,8 @@ void Classifier::endOfIteration() nn->endOfIteration(); } +bool Classifier::isDynamic() +{ + return dynamic; +} + diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index 28719567e6ee096457928aa634c63be06d0b363f..5bcfc8daddbdf5774788520c33db186ede7106fa 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -20,6 +20,7 @@ Config::Config(BD & bd, const std::string inputFilename) : bd(bd), hashHistory(H for(int i = 0; i < bd.getNbLines(); i++) tapes.emplace_back(bd.getNameOfLine(i), bd.lineIsKnown(i)); this->totalEntropy = 0; + readInput(); } Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHistory), pastActions(other.pastActions) @@ -44,7 +45,7 @@ Config::Config(const Config & other) : bd(other.bd), hashHistory(other.hashHisto this->file.reset(new File(*other.file.get())); } -Config::Tape::Tape(const std::string & name, bool isKnown) : ref(ProgramParameters::readSize*4+1, Dict::unknownValueStr), hyp(ProgramParameters::readSize*4+1, std::make_pair(Dict::unknownValueStr, 0.0)) +Config::Tape::Tape(const std::string & name, bool isKnown) : ref(ProgramParameters::readSize, Dict::unknownValueStr), hyp(ProgramParameters::readSize, std::make_pair(Dict::unknownValueStr, 0.0)) { this->head = 0; this->name = name; @@ -94,76 +95,143 @@ void Config::readInput() FILE * fd = file->getDescriptor(); char buffer[100000]; - std::vector<std::string> cols; - unsigned int usualColsSize = 0; - int toRead = ProgramParameters::readSize; - int haveRead = 0; + int lineIndex = 0; - while(haveRead < toRead && fscanf(fd, "%[^\n]\n", buffer) == 1) + while (fscanf(fd, "%[^\n]\n", buffer) == 1) { + lineIndex++; + if (!utf8::is_valid(buffer, buffer+std::strlen(buffer))) { - fprintf(stderr, "ERROR (%s) : input (%s) line %d is not toally utf-8 formated. Aborting.\n", ERRINFO, inputFilename.c_str(), tapes[0].size()); + fprintf(stderr, "ERROR (%s) : input (%s) line %d is not toally utf-8 formated. Aborting.\n", ERRINFO, inputFilename.c_str(), lineIndex); exit(1); } - cols = split(buffer, '\t'); - if (!usualColsSize) - usualColsSize = cols.size(); + if (std::strlen(buffer) <= 3) + continue; - if (cols.size() != usualColsSize) + if (split(buffer, '=')[0] == "# sent_id ") + inputContent.emplace_back(); + else if (buffer[0] == '#' && split(buffer, '=')[0] != "# text ") + continue; + + inputContent.back().emplace_back(buffer); + } + + inputAllRead = true; + fillTapesWithInput(); +} + +void Config::fillTapesWithInput() +{ + rawInput = ""; + std::vector<std::string> cols; + unsigned int usualColsSize = 0; + auto & ids = getTape("ID"); + auto & govs = getTape("GOV"); + + for (auto & sentence : inputContent) + { + int sentenceStartIndex = ids.refSize(); + for (unsigned int wordIndex = 0; wordIndex < sentence.size(); wordIndex++) { - fprintf(stderr, "ERROR (%s) : input (%s) line %d has %lu columns instead of %u. Aborting.\n", ERRINFO, inputFilename.c_str(), tapes[0].size(), cols.size(), usualColsSize); - exit(1); + auto & word = sentence[wordIndex]; + if (split(word, '=')[0] == "# text ") + { + std::string prefix = rawInput.empty() ? "" : " "; + if (choiceWithProbability(0.3)) + prefix = "\n"; + else if (choiceWithProbability(0.3)) + prefix = ""; + if (rawInput.empty()) + prefix = ""; + rawInput += prefix + std::string(word.begin()+9, word.end()); + continue; + } + else if (word[0] == '#') + continue; + + cols = split(word, '\t'); + if (!usualColsSize) + usualColsSize = cols.size(); + + if (cols.size() != usualColsSize) + { + fprintf(stderr, "ERROR (%s) : input (%s) line %d has %lu columns instead of %u. Aborting.\n", ERRINFO, inputFilename.c_str(), tapes[0].size(), cols.size(), usualColsSize); + exit(1); + } + + for(unsigned int i = 0; i < cols.size(); i++) + if(bd.hasLineOfInputCol(i)) + { + auto & tape = getTapeByInputCol(i); + + tape.addToRef(cols[i]); + tape.addToHyp(""); + + if (tape.getName() == ProgramParameters::tapeToMask) + if (choiceWithProbability(ProgramParameters::maskRate)) + tape.maskIndex(tape.refSize()-1); + if (tape.getName() == ProgramParameters::sequenceDelimiterTape) + { + fprintf(stderr, "ERROR (%s) : Tape \'%s\' must not be given as a column in the input since it's the sequence delimiter. Aborting.\n", ERRINFO, tape.getName().c_str()); + exit(1); + } + } + getTape(ProgramParameters::sequenceDelimiterTape).addToRef(wordIndex == sentence.size()-1 ? ProgramParameters::sequenceDelimiter : ""); + getTape(ProgramParameters::sequenceDelimiterTape).addToHyp(""); } - printAsOutput(outputFile, tapes[0].getNextOverridenDataIndex(), tapes[0].getNextOverridenRealIndex()); + for (int word = sentenceStartIndex; word < ids.refSize(); word++) + { + if (split(ids.getRef(word), '-').size() > 1) + continue; + if (govs.getRef(word) == "0") + continue; - for(unsigned int i = 0; i < cols.size(); i++) - if(bd.hasLineOfInputCol(i)) + try { - auto & tape = getTapeByInputCol(i); + int id = std::stoi(ids.getRef(word)); + std::string goalId = govs.getRef(word); + int relativeIndex = 0; - tape.addToRef(cols[i]); - tape.addToHyp(""); + if (std::stoi(goalId) < id) + { + while (ids.getRef(word+relativeIndex) != goalId) + { + if (--relativeIndex+word < 0) + throw ""; + } + } + else + { + while (ids.getRef(word+relativeIndex) != goalId) + if (++relativeIndex+word >= ids.refSize()) + throw ""; + } - if (tape.getName() == ProgramParameters::tapeToMask) - if (choiceWithProbability(ProgramParameters::maskRate)) - tape.maskIndex(tape.refSize()-1); + govs.setRef(word, std::to_string(relativeIndex)); } - - haveRead++; + catch (std::exception &) + { + fprintf(stderr, "ERROR (%s) : invalid governor '%s' '%s'. Aborting.\n", ERRINFO, govs.getRef(word).c_str(), getTape("FORM").getRef(word).c_str()); + exit(1); + } + } } // Making all tapes the same size int maxTapeSize = 0; for(auto & tape : tapes) maxTapeSize = std::max<unsigned int>(maxTapeSize, tape.refSize()); - - if (haveRead < toRead || tapes[0].size() == ProgramParameters::tapeSize) - { - printAsOutput(outputFile, tapes[0].getNextOverridenDataIndex(), tapes[0].getNextOverridenRealIndex()); - inputAllRead = true; - } - for(auto & tape : tapes) { - while(tape.refSize() < maxTapeSize) + while (tape.refSize() < maxTapeSize) tape.addToRef(""); - - while(tape.hypSize() < maxTapeSize) + while (tape.hypSize() < maxTapeSize) tape.addToHyp(""); - - if (inputAllRead) - { - tape.addToRef("0"); - tape.addToHyp(""); - } } - - if (hasTape("TEXT")) - updateRawInput(); } void Config::printForDebug(FILE * output) @@ -233,7 +301,7 @@ void Config::printAsExample(FILE *) exit(1); } -void Config::printAsOutput(FILE * output, int dataIndex, int realIndex) +void Config::printAsOutput(FILE * output, int dataIndex, int realIndex, bool forceRef) { if (dataIndex == -1 || !output) return; @@ -243,27 +311,60 @@ void Config::printAsOutput(FILE * output, int dataIndex, int realIndex) std::vector< std::pair<std::string, float> > toPrint; for (unsigned int j = 0; j < tapes.size(); j++) { + int outputTapeIndex = bd.getOutputIndexOfLine(j); + + while ((int)toPrint.size() < outputTapeIndex+1) + toPrint.emplace_back("_", 0.0); + if(bd.mustPrintLine(j)) - toPrint.emplace_back(tapes[j][dataIndex-head].empty() ? "_" : tapes[j][dataIndex-head].c_str(), tapes[j].getEntropy(dataIndex-head)); + { + if (!forceRef) + toPrint[outputTapeIndex] = {tapes[j][dataIndex-head].empty() ? "_" : tapes[j][dataIndex-head].c_str(), tapes[j].getEntropy(dataIndex-head)}; + else + toPrint[outputTapeIndex] = {tapes[j].getRef(dataIndex-head).empty() ? "_" : tapes[j].getRef(dataIndex-head).c_str(), tapes[j].getEntropy(dataIndex-head)}; + } + } + + bool allEmpty = true; + + for (auto & it : toPrint) + if (it.first != "_" && !it.first.empty()) + { + allEmpty = false; + break; + } + + if (allEmpty) + return; + + ProgramOutput::instance.addLine(output, toPrint, realIndex); + + if (!ProgramParameters::delayedOutput) + { + auto eos = forceRef ? getTape(ProgramParameters::sequenceDelimiterTape).getRef(dataIndex-head) : getTape(ProgramParameters::sequenceDelimiterTape)[dataIndex-head]; + if (eos == ProgramParameters::sequenceDelimiter) + fprintf(output, "\n"); } - ProgramOutput::instance.addLine(toPrint, realIndex); } void Config::moveHead(int mvt) { - if (head + mvt < tapes[0].size()) + if (head + mvt <= tapes[0].size()) { head += mvt; - if (hasTape("ID") && split(getTape("ID").getHyp(0), '-').size() <= 1) - currentWordIndex += mvt; + if (mvt > 0) + for (int i = 0; i < mvt; i++) + if (hasTape("ID") && split(getTape("ID").getHyp(i), '-').size() <= 1) + currentWordIndex += 1; + if (mvt < 0) + for (int i = 0; i < mvt; i++) + if (hasTape("ID") && split(getTape("ID").getHyp(-i), '-').size() <= 1) + currentWordIndex -= 1; for (auto & tape : tapes) tape.moveHead(mvt); - - if (mvt > 0 && head % ProgramParameters::readSize == 0 && head >= (3*ProgramParameters::readSize)) - readInput(); } else if (!endOfTapes()) { @@ -296,10 +397,10 @@ void Config::moveRawInputHead(int mvt) bool Config::isFinal() { - if (!ProgramParameters::rawInput) - return endOfTapes() && stack.empty(); + if (rawInputHeadIndex > 0 && !rawInput.empty()) + return (rawInputHeadIndex >= (int)rawInput.size()); - return (rawInputHeadIndex >= (int)rawInput.size()); + return endOfTapes() && stack.empty(); } void Config::reset() @@ -316,15 +417,10 @@ void Config::reset() stack.clear(); stackHistory = -1; - inputAllRead = false; head = 0; rawInputHead = 0; rawInputHeadIndex = 0; currentWordIndex = 1; - - file.reset(); - while (tapes[0].size() < ProgramParameters::readSize*4 && !inputAllRead) - readInput(); } const std::string & Config::Tape::operator[](int relativeIndex) @@ -358,6 +454,19 @@ void Config::Tape::setHyp(int relativeIndex, const std::string & elem) hyp.set(head + relativeIndex, std::pair<std::string,float>(elem,totalEntropy)); } +void Config::Tape::setRef(int relativeIndex, const std::string & elem) +{ + ref.set(head + relativeIndex, elem); +} + +void Config::Tape::set(int relativeIndex, const std::string & elem) +{ + if(isKnown) + return setRef(relativeIndex, elem); + + return setHyp(relativeIndex, elem); +} + std::string & Config::getCurrentStateName() { if(currentStateName.empty()) @@ -398,59 +507,9 @@ LimitedStack<float> & Config::getCurrentStateEntropyHistory() return entropyHistory.find(getCurrentStateName())->second; } -void Config::shuffle(const std::string & delimiterTape, const std::string & delimiter) +void Config::shuffle() { - struct Trio{unsigned int a; unsigned int b; unsigned int c; Trio(unsigned int a, unsigned int b, unsigned int c): a(a), b(b), c(c){}}; - std::vector<Trio> delimiters; - - if (delimiterTape == "0") - { - unsigned int previousIndex = 0; - for (int i = 0; i < tapes[0].refSize(); i++) - { - delimiters.emplace_back(previousIndex, i, delimiters.size()); - previousIndex = i+1; - } - } - else - { - auto & tape = getTape(delimiterTape); - unsigned int previousIndex = 0; - for (int i = 0; i < tape.refSize(); i++) - if (tape.getRef(i-head) == delimiter) - { - delimiters.emplace_back(previousIndex, i, delimiters.size()); - previousIndex = i+1; - } - } - - if (delimiters.empty()) - { - fprintf(stderr, "WARNING (%s) : Requested to shuffle based on tape \'%s\' with \'%s\' as a delimiter, but none has been found. Aborting.\n", ERRINFO, delimiterTape.c_str(), delimiter.c_str()); - return; - } - - std::pair<unsigned int, unsigned int> suffix = {delimiters.back().b+1, tapes[0].refSize()-1}; - - std::random_shuffle(delimiters.begin(), delimiters.end()); - - auto newTapes = tapes; - - for (unsigned int tape = 0; tape < tapes.size(); tape++) - { - newTapes[tape].clearDataForCopy(); - - for (auto & delimiter : delimiters) - newTapes[tape].copyPart(tapes[tape], delimiter.a, delimiter.b+1); - - if (suffix.first <= suffix.second) - newTapes[tape].copyPart(tapes[tape], suffix.first, suffix.second+1); - } - - tapes = newTapes; - - if (!rawInput.empty()) - updateRawInput(); + std::random_shuffle(inputContent.begin(), inputContent.end()); } int Config::stackGetElem(int index) const @@ -568,12 +627,12 @@ void Config::Tape::moveHead(int mvt) bool Config::endOfTapes() const { - return inputAllRead && tapes[0].headIsAtEnd(); + return inputAllRead && (tapes[0].headIsAtEnd() || rawInputHeadIndex >= (int)rawInput.size()); } bool Config::Tape::headIsAtEnd() const { - return head == ref.getLastIndex(); + return head >= ref.getLastIndex(); } int Config::Tape::size() @@ -641,23 +700,26 @@ int Config::Tape::getNextOverridenRealIndex() return ref.getNextOverridenRealIndex(); } -void Config::printTheRest() +void Config::printTheRest(bool forceRef) { if (!outputFile) return; + updateIdsInSequence(); + setGovsAsUD(forceRef); + int tapeSize = tapes[0].size(); int goalPrintIndex = lastIndexPrinted; - int realIndex = tapeSize - 1 - ((((tapes[0].dataSize()-(goalPrintIndex == -1 ? 0 : 0)))-(goalPrintIndex+1))+(goalPrintIndex)); - for (int i = goalPrintIndex+1; i < (tapes[0].dataSize()-(goalPrintIndex == -1 ? 1 : 0)); i++) + int realIndex = tapeSize - ((((tapes[0].dataSize()-(goalPrintIndex == -1 ? 0 : 0)))-(goalPrintIndex+1))+(goalPrintIndex)); + for (int i = goalPrintIndex+1; i < tapes[0].dataSize(); i++) { - printAsOutput(outputFile, i, realIndex); + printAsOutput(outputFile, i, realIndex, forceRef); realIndex++; } for (int i = 0; i < goalPrintIndex; i++) { - printAsOutput(outputFile, i, realIndex); + printAsOutput(outputFile, i, realIndex, forceRef); realIndex++; } } @@ -735,14 +797,86 @@ float Config::Tape::getScore(int from, int to) return 100.0*res / (1+to-from); } -void Config::updateRawInput() +int Config::Tape::getHead() { - rawInput = ""; - auto & textTape = getTape("TEXT"); - for (int i = 0; i < textTape.size(); i++) + return head; +} + +void Config::transformSymbol(const std::string & from, const std::string & to) +{ + for (auto & tape : tapes) + for (int i = 0; i < tape.size(); i++) + if (tape.getHyp(i-tape.getHead()) == from) + tape.setHyp(i-tape.getHead(), to); +} + +void Config::setLastIndexPrinted(int lastIndexPrinted) +{ + this->lastIndexPrinted = lastIndexPrinted; +} + +void Config::setGovsAsUD(bool ref) +{ + auto & ids = getTape("ID"); + auto & govs = getTape("GOV"); + + if (ref) + for (int i = 0; i < ids.refSize(); i++) + { + try + { + int relativeIndex = std::stoi(govs.getRef(i-head)); + if (relativeIndex == 0) + continue; + auto idOfGov = ids.getRef(i+relativeIndex-head); + govs.setRef(i-head, idOfGov); + } + catch (std::exception &) {continue;} + } + else + for (int i = 0; i < ids.hypSize(); i++) + { + try + { + int relativeIndex = std::stoi(govs.getHyp(i-head)); + if (relativeIndex == 0) + continue; + auto idOfGov = ids.getHyp(i+relativeIndex-head); + govs.setHyp(i-head, idOfGov); + } + catch (std::exception &) {continue;} + } +} + +void Config::updateIdsInSequence() +{ + int sentenceEnd = getHead(); + auto & eos = getTape(ProgramParameters::sequenceDelimiterTape); + auto & ids = getTape("ID"); + while (sentenceEnd >= 0 && eos[sentenceEnd-getHead()] != ProgramParameters::sequenceDelimiter) + sentenceEnd--; + int sentenceStart = std::max(0,sentenceEnd-1); + while (sentenceStart >= 0 && eos[sentenceStart-getHead()] != ProgramParameters::sequenceDelimiter) + sentenceStart--; + sentenceStart++; + + if (sentenceEnd < 0) + { + sentenceStart = 0; + sentenceEnd = eos.hypSize()-1; + } + + int curId = 1; + for (int i = sentenceStart; i <= sentenceEnd; i++) { - if (textTape[i] != "_") - rawInput += (rawInput.empty() ? std::string("") : (choiceWithProbability(0.5) ? std::string(" ") : std::string("\n"))) + textTape[i]; + auto splited = split(ids.getRef(i-getHead()), '-'); + if (splited.size() == 1) + { + ids.setHyp(i-getHead(), std::to_string(curId++)); + continue; + } + int multiWordSize = splited.size(); + ids.setHyp(i-getHead(), std::to_string(curId)+"-"+std::to_string(curId+multiWordSize-1)); } } diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 8ba8eb23fb5f5431f73e602511d14f8629ed5b25..2242fc190e0efd7987b5f41d3105d6f7e4388377 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -195,7 +195,7 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - return (action == "WRITE b.0 POS " + c.getTape("POS").getRef(0) || c.endOfTapes()) ? 0 : 1; + return action == "WRITE b.0 POS " + c.getTape("POS").getRef(0) ? 0 : 1; }))); str2oracle.emplace("tokenizer", std::unique_ptr<Oracle>(new Oracle( @@ -263,7 +263,7 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - return (action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).getRef(0) == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")) || c.endOfTapes()) ? 0 : 1; + return action == "WRITE b.0 " + ProgramParameters::sequenceDelimiterTape + " " + (c.getTape(ProgramParameters::sequenceDelimiterTape).getRef(0) == std::string(ProgramParameters::sequenceDelimiter) ? std::string(ProgramParameters::sequenceDelimiter) : std::string("0")); }))); str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle( @@ -279,7 +279,7 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - return (action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0) || c.endOfTapes()) ? 0 : 1; + return action == "WRITE b.0 MORPHO " + c.getTape("MORPHO").getRef(0) ? 0 : 1; }))); str2oracle.emplace("strategy_morpho", std::unique_ptr<Oracle>(new Oracle( @@ -387,6 +387,12 @@ void Oracle::createDatabase() movement = 1; newState = "signature"; } + + if (movement > 0 && c.endOfTapes()) + movement = 0; + + if (split(previousAction, ' ')[0] == "eos" && c.endOfTapes()) + return std::string(""); } else if (previousState == "error_parser") { @@ -475,7 +481,7 @@ void Oracle::createDatabase() while (start+c.getHead() < c.getTape("SGN").size() && !c.getTape("SGN").getHyp(start).empty()) start++; - while (end >= 0 && c.getTape("FORM").getHyp(end).empty()) + while (end >= 0 && c.getTape("FORM")[end].empty()) end--; if (start > end) @@ -568,7 +574,7 @@ void Oracle::createDatabase() const std::string & lemma = c.getTape("LEMMA").getRef(0); std::string rule = getRule(toLowerCase(form), toLowerCase(lemma)); - return (action == std::string("RULE LEMMA ON FORM ") + rule || c.endOfTapes()) ? 0 : 1; + return action == std::string("RULE LEMMA ON FORM ") + rule ? 0 : 1; }))); str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle( @@ -584,26 +590,25 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - bool hasId = c.hasTape("ID"); - + auto & ids = c.getTape("ID"); auto & labels = c.getTape("LABEL"); auto & govs = c.getTape("GOV"); auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape); int head = c.getHead(); - int stackHead = c.stackEmpty() ? 0 : c.stackTop(); - int stackGov = 0; - bool stackNoGov = false; - try - { - stackGov = stackHead + std::stoi(govs.getRef(stackHead-head)); - } catch (std::exception &){stackNoGov = true;} - int headGov = 0; - bool headNoGov = false; + bool headIsMultiword = split(ids.getRef(0), '-').size() > 1; + int headGov = -1; try {headGov = head + std::stoi(govs.getRef(0));} - catch (std::exception &){headNoGov = true;} - int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1; - int sentenceEnd = c.getHead(); + catch (std::exception &) {headGov = -1;} + + int stackHead = c.stackEmpty() ? 0 : c.stackTop(); + bool stackHeadIsMultiword = split(ids.getRef(stackHead-head), '-').size() > 1; + int stackGov = -1; + try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));} + catch (std::exception &) {stackGov = -1;} + + int sentenceStart = head-1 < 0 ? 0 : head-1; + int sentenceEnd = head; int cost = 0; @@ -620,29 +625,19 @@ void Oracle::createDatabase() if (parts[0] == "SHIFT") { - if (headNoGov) + if (headIsMultiword) return 0; - for (int i = sentenceStart; i <= sentenceEnd; i++) + for (int j = 0; j < c.stackSize(); j++) { - if (!isNum(govs.getRef(i-head))) - { - continue; - fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.getRef(i-head).c_str()); - exit(1); - } - - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} - - for (int j = 0; j < c.stackSize(); j++) + auto s = c.stackGetElem(j); + try { - auto s = c.stackGetElem(j); - if (s == i) - if (otherGov == head || headGov == s) - cost++; + int sGov = s + std::stoi(govs.getRef(s-head)); + if (sGov == head || headGov == s) + cost++; } + catch (std::exception &) {continue;} } if (c.stackSize() && stackHead == head) @@ -656,9 +651,9 @@ void Oracle::createDatabase() if (object[0] == "b") { if (parts[2] == "LABEL") - return (action == "WRITE b.0 LABEL " + c.getTape("LABEL").getRef(0) || c.endOfTapes() || c.getTape("LABEL").getRef(0) == "root") ? 0 : 1; + return (action == "WRITE b.0 LABEL " + c.getTape("LABEL").getRef(0) || c.getTape("LABEL").getRef(0) == "root") ? 0 : 1; else if (parts[2] == "GOV") - return (action == "WRITE b.0 GOV " + c.getTape("GOV").getRef(0) || c.endOfTapes()) ? 0 : 1; + return (action == ("WRITE b.0 GOV " + c.getTape("GOV").getRef(0))) ? 0 : 1; } else if (object[0] == "s") { @@ -673,18 +668,18 @@ void Oracle::createDatabase() } else if (parts[0] == "REDUCE") { - if (stackNoGov) + if (stackHeadIsMultiword) return 0; - if (stackGov == stackHead) - cost++; for (int i = head; i <= sentenceEnd; i++) { - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){continue;} - if (otherGov == stackHead) - cost++; + try + { + int otherGov = i + std::stoi(govs.getRef(i-head)); + if (otherGov == stackHead || stackGov == i) + cost++; + } + catch (std::exception &) {continue;} } if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter) @@ -694,22 +689,21 @@ void Oracle::createDatabase() } else if (parts[0] == "LEFT") { - if (stackNoGov) - return 0; - - if (stackGov == stackHead) - cost++; + if (stackHeadIsMultiword || headIsMultiword) + return 1; if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter) cost++; for (int i = head+1; i <= sentenceEnd; i++) { - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){continue;} - if (otherGov == stackHead || stackGov == i) - cost++; + try + { + int otherGov = i + std::stoi(govs.getRef(i-head)); + if (otherGov == stackHead || stackGov == i) + cost++; + } + catch (std::exception &) {continue;} } if (stackGov != head) @@ -725,8 +719,8 @@ void Oracle::createDatabase() } else if (parts[0] == "RIGHT") { - if (headNoGov) - return 0; + if (stackHeadIsMultiword || headIsMultiword) + return 1; for (int j = 0; j < c.stackSize(); j++) { @@ -735,17 +729,22 @@ void Oracle::createDatabase() if (s == c.stackTop()) continue; - int otherGov = 0; - try {otherGov = s + std::stoi(govs.getRef(s-head));} - catch (std::exception &){continue;} - if (otherGov == head || headGov == s) - cost++; + try + { + int otherGov = s + std::stoi(govs.getRef(s-head)); + if (otherGov == head || headGov == s) + cost++; + } + catch (std::exception &) {continue;} } for (int i = head; i <= sentenceEnd; i++) if (headGov == i) cost++; + if (headGov != stackHead) + cost++; + if (parts.size() == 1) return cost; if (labels.getRef(0) == parts[1])