diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 371947a8455b8d5a73ed22219506cf5e3cf879d0..65e07b08babbf3982d880406b75a4b5ceb7eeee4 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -95,6 +95,8 @@ std::pair<float,std::string> getClassifierAction(Config & config, Classifier::We } std::string & predictedAction = weightedActions[0].second.second; + if (predictedAction.empty()) + throw EndOfDecode(); float proba = weightedActions[0].second.first; Action * action = classifier->getAction(predictedAction); @@ -197,6 +199,7 @@ void applyActionAndTakeTransition(TransitionMachine & tm, const std::string & ac void Decoder::decode() { config.reset(); + config.fillTapesWithInput(); if (ProgramParameters::beamSize > 1) decodeBeam(); @@ -233,7 +236,7 @@ void Decoder::decodeNoBeam() std::pair<float,std::string> predictedAction; try {predictedAction = getClassifierAction(config, weightedActions, tm.getCurrentClassifier(), 0);} - catch(EndOfDecode &) {continue;} + catch(EndOfDecode &) {break;} catch(NoMoreActions &) {continue;}; checkAndRecordError(config, tm.getCurrentClassifier(), weightedActions, predictedAction.second, errors); diff --git a/maca_common/src/util.cpp b/maca_common/src/util.cpp index 9048efe3291c35ddb7cfc8b53450ffffda91a27d..20dff0dc39fa2855e95b6da275a048af76356409 100644 --- a/maca_common/src/util.cpp +++ b/maca_common/src/util.cpp @@ -571,7 +571,7 @@ std::string shrinkString(const std::string & base, int maxSize, const std::strin std::string strip(const std::string & s) { std::string res; - unsigned int i; + unsigned int i = 0; while (i < s.size() && isSeparator(s[i])) i++; while (i < s.size() && !isSeparator(s[i])) diff --git a/trainer/src/TrainInfos.cpp b/trainer/src/TrainInfos.cpp index 7cf5e946ad62f152f3715a33eab0759e85fe7269..2058d5446fe53e9fe952c02ba64e9ebabcd909b0 100644 --- a/trainer/src/TrainInfos.cpp +++ b/trainer/src/TrainInfos.cpp @@ -208,7 +208,7 @@ void TrainInfos::computeTrainScores(Config & c) else if (it.first == "Tokenizer") addTrainScore(it.first, scoresFloat["Tokens"]); else if (it.first == "Tagger") - addTrainScore(it.first, scoresFloat["XPOS"]); + addTrainScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") addTrainScore(it.first, scoresFloat["UFeats"]); else if (it.first == "Lemmatizer_Rules") @@ -264,7 +264,7 @@ void TrainInfos::computeDevScores(Config & c) else if (it.first == "Tokenizer") addDevScore(it.first, scoresFloat["Tokens"]); else if (it.first == "Tagger") - addDevScore(it.first, scoresFloat["XPOS"]); + addDevScore(it.first, scoresFloat["UPOS"]); else if (it.first == "Morpho") addDevScore(it.first, scoresFloat["UFeats"]); else if (it.first == "Lemmatizer_Rules") diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 856617ff0aed37aabfa34df9df72464b03fca989..9aafaf4e94b83783da3f9bcc0acc2e5790733e87 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -49,6 +49,7 @@ void Trainer::computeScoreOnDev() tm.reset(); devConfig->reset(); + devConfig->fillTapesWithInput(); if (ProgramParameters::debug) fprintf(stderr, "Computing score on dev set\n"); @@ -71,8 +72,11 @@ void Trainer::computeScoreOnDev() if(!tm.getCurrentClassifier()->needsTrain()) { - int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(*devConfig); - std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); + std::string neededActionName = tm.getCurrentClassifier()->getOracleAction(*devConfig); + + if (neededActionName.empty()) + break; + Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); action->setInfos(tm.getCurrentClassifier()->name); @@ -207,12 +211,14 @@ void Trainer::resetAndShuffle() if(ProgramParameters::shuffleExamples) trainConfig.shuffle(); + + trainConfig.fillTapesWithInput(); } void Trainer::doStepNoTrain() { - int neededActionIndex = tm.getCurrentClassifier()->getOracleActionIndex(trainConfig); - std::string neededActionName = tm.getCurrentClassifier()->getActionName(neededActionIndex); + std::string neededActionName = tm.getCurrentClassifier()->getOracleAction(trainConfig); + if (ProgramParameters::debug) { fprintf(stderr, "Speed : %s actions/s\n", int2humanStr((int)currentSpeed).c_str()); @@ -220,6 +226,9 @@ void Trainer::doStepNoTrain() fprintf(stderr, "action=<%s>\n", neededActionName.c_str()); } + if (neededActionName.empty()) + throw EndOfIteration(); + Action * action = tm.getCurrentClassifier()->getAction(neededActionName); TransitionMachine::Transition * transition = tm.getTransition(neededActionName); action->setInfos(tm.getCurrentClassifier()->name); @@ -322,7 +331,7 @@ void Trainer::doStepTrain() fprintf(stdout, "%s\t%s\t%s\n", tm.getCurrentClassifier()->getFeatureModel()->filename.c_str(), oAction.c_str(), features.c_str()); } - if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) + if (tm.getCurrentClassifier()->isDynamic() && TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability)) { actionName = pAction; TI.lastActionWasPredicted[trainConfig.getCurrentStateName()] = true; @@ -539,7 +548,8 @@ void Trainer::train() tm.getCurrentClassifier()->initClassifier(trainConfig); if(!tm.getCurrentClassifier()->needsTrain()) - doStepNoTrain(); + try {doStepNoTrain();} + catch (EndOfIteration &) {break;} else try {doStepTrain();} catch (EndOfIteration &) {break;} diff --git a/transition_machine/include/BD.hpp b/transition_machine/include/BD.hpp index 7e7aea8931d640d9375037255c12d52f54b9e094..bdffaaf26d4bce8da9dbef6bd48f215957973678 100644 --- a/transition_machine/include/BD.hpp +++ b/transition_machine/include/BD.hpp @@ -35,6 +35,8 @@ class BD /// /// If this Line's values don't need to be predicted. bool isKnown; + /// @brief What column will it be in the output. + int outputIndex; /// @brief Create a new Line. /// @@ -44,7 +46,7 @@ class BD /// @param inputColumn The column of the MCD this Line corresponds to. /// @param mustPrint Whether or not this Line is part of the expected output of the program. /// @param isKnown Whether or not the entirety of this Line is already known. - Line(int num, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown); + Line(int num, int outputIndex, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown); }; private : @@ -111,6 +113,12 @@ class BD /// /// @return Whether or not this Line is part of the output. bool mustPrintLine(int index); + /// @brief Return the column index of line index in the outut. + /// + /// @param index The index of the Line. + /// + /// @return The column index of this line in the output. + int getOutputIndexOfLine(int index); /// @brief Get the name of a Line from its index. /// /// @param line The index of the Line. diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index fb134400a6145214f381f70a44b811373f96f46a..d85935bfcc6089725e95d8af7a87e8957921395f 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -62,6 +62,8 @@ class Classifier /// the correct Action to associate with a training example.\n /// For Classifier of type Information, the Oracle is used in train mode and decode mode too, it is simply a deterministic function that gives the correct Action given a Configuration. Oracle * oracle; + /// @brief Is this classifier subject to a dynamic oracle. + bool dynamic; private : @@ -238,6 +240,8 @@ class Classifier FeatureModel * getFeatureModel(); /// @brief Prepare Classifier for next iteration. void endOfIteration(); + /// @brief Is this classifier subject to a dynamic oracle. + bool isDynamic(); }; #endif diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index b4996c3d35c4725c4061b5876919409f3edd9f43..d89f93456c5d9a7b0db1a7edd2dad87bca8c2f97 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -73,6 +73,12 @@ class Config /// @param relativeIndex The index of the cell relatively to the head. /// @param elem The new content of the cell. void setHyp(int relativeIndex, const std::string & elem); + /// @brief Set the value of a cell of the ref. + /// + /// @param relativeIndex The index of the cell relatively to the head. + /// @param elem The new content of the cell. + void setRef(int relativeIndex, const std::string & elem); + void set(int relativeIndex, const std::string & elem); int getHead(); /// @brief Return true if the head of this tape is on the last cell. /// @@ -367,6 +373,10 @@ class Config std::vector< std::pair<std::string, int> > & getActionsHistory(std::string & state); void transformSymbol(const std::string & from, const std::string & to); void setLastIndexPrinted(int lastIndexPrinted); + /// @brief Transform the tape GOV from relative indexes to UD format. + void setGovsAsUD(bool ref); + /// @brief Update the IDs in the last predicted sequence. + void updateIdsInSequence(); }; #endif diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 5742db20145fa9d456063aba3af109491d0f3ff4..d68d017de9de645b65d6129c5660044cb5742099 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -163,7 +163,7 @@ Action::BasicAction ActionBank::pushHead() auto undo = [](Config & c, Action::BasicAction &) {c.stackPop();}; auto appliable = [](Config & c, Action::BasicAction &) - {return !(c.stackSize() >= ProgramParameters::maxStackSize) && (!c.endOfTapes());}; + {return !(c.stackSize() >= ProgramParameters::maxStackSize || (!c.stackEmpty() && c.stackTop() == c.getHead()));}; Action::BasicAction basicAction = {Action::BasicAction::Type::Push, "", apply, undo, appliable}; @@ -184,7 +184,7 @@ Action::BasicAction ActionBank::stackPop(bool checkGov) if (!checkGov) return true; - return !c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty(); + return split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1 || (!c.getTape("GOV").getHyp(c.stackTop()-c.getHead()).empty() && c.stackTop() != c.getHead()); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Pop, "", apply, undo, appliable}; @@ -428,6 +428,12 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na return false; int b0 = c.getHead(); int s0 = c.stackTop(); + + if (split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + return simpleBufferWriteAppliable(c, "GOV", s0-b0); }; Action::BasicAction basicAction = @@ -479,7 +485,13 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na }; auto appliable = [](Config & c, Action::BasicAction &) { - return !c.stackEmpty() && !c.endOfTapes() && simpleBufferWriteAppliable(c, "GOV", 0); + if (c.stackEmpty()) + return false; + if (split(c.getTape("ID").getRef(0), '-').size() > 1) + return false; + if (split(c.getTape("ID").getRef(c.stackTop()-c.getHead()), '-').size() > 1) + return false; + return simpleBufferWriteAppliable(c, "GOV", 0); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, "", apply, undo, appliable}; @@ -510,6 +522,7 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } else if(std::string(b1) == "EOS") { + // Puting the EOS tag on the last element of the sentence. auto apply0 = [b2](Config & c, Action::BasicAction &) { int b0 = c.getHead(); @@ -531,22 +544,27 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(basicAction0); + // Chosing root of the sentence and attaching floating words to it. auto apply = [](Config & c, Action::BasicAction & ba) { ba.data = ""; auto & govs = c.getTape("GOV"); + auto & ids = c.getTape("ID"); int b0 = c.getHead(); int rootIndex = -1; for (int i = c.stackSize()-1; i >= 0; i--) { auto s = c.stackGetElem(i); + if (split(ids.getRef(s-b0), '-').size() > 1) + continue; if (govs.getHyp(s-b0).empty() || govs.getHyp(s-b0) == "0") { if (rootIndex == -1) rootIndex = s; else { - simpleBufferWrite(c, "GOV", std::to_string(rootIndex - s), s-b0); + simpleBufferWrite(c, "GOV", std::to_string(rootIndex-s), s-b0); + simpleBufferWrite(c, "LABEL", "_", s-b0); ba.data += "+"+std::to_string(s-b0); } } @@ -565,26 +583,10 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na } simpleBufferWrite(c, "GOV", "0", rootIndex-b0); + simpleBufferWrite(c, "LABEL", "root", rootIndex-b0); // Delete the arcs from the previous sentence to the new sentence - - auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape); - - for (int i = b0; i >= 0; i--) - { - if (eos[i-b0] == ProgramParameters::sequenceDelimiter) - break; - - try - { - int govIndex = i + std::stoi(govs[i-b0]); - if (govIndex <= c.stackGetElem(0)) - { - simpleBufferWrite(c, "GOV", std::to_string(rootIndex - i), i-b0); - simpleBufferWrite(c, "LABEL", "_", i-b0); - } - } catch (std::exception &) {} - } + // TODO }; auto undo = [](Config & c, Action::BasicAction & ba) { @@ -596,13 +598,17 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na if (govs.getHyp(s-b0) == "0") { simpleBufferWrite(c, "GOV", "", s-b0); + simpleBufferWrite(c, "LABEL", "", s-b0); break; } } auto deps = split(ba.data, '+'); for (auto s : deps) if (!s.empty()) + { simpleBufferWrite(c, "GOV", "", std::stoi(s)); + simpleBufferWrite(c, "LABEL", "", std::stoi(s)); + } ba.data.clear(); }; auto appliable = [](Config & c, Action::BasicAction &) @@ -614,61 +620,8 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(basicAction); - auto apply2 = [b2](Config & c, Action::BasicAction & ba) - { - ba.data = ""; - auto & labels = c.getTape("LABEL"); - int b0 = c.getHead(); - int rootIndex = -1; - - for (int i = c.stackSize()-1; i >= 0; i--) - { - auto s = c.stackGetElem(i); - if (labels.getHyp(s-b0).empty()) - { - if (rootIndex == -1) - { - rootIndex = 0; - simpleBufferWrite(c, "LABEL", "root", s-b0); - } - else - { - simpleBufferWrite(c, "LABEL", "_", s-b0); - ba.data += "+"+std::to_string(s-b0); - } - } - } - }; - auto undo2 = [](Config & c, Action::BasicAction & ba) - { - auto & labels = c.getTape("LABEL"); - int b0 = c.getHead(); - - for (int i = c.stackSize()-1; i >= 0; i--) - { - auto s = c.stackGetElem(i); - if (labels.getHyp(s-b0) == "root") - { - simpleBufferWrite(c, "LABEL", "", s-b0); - break; - } - } - auto deps = split(ba.data, '+'); - for (auto & dep : deps) - if (!dep.empty()) - simpleBufferWrite(c, "LABEL", "", std::stoi(dep)); - ba.data.clear(); - }; - auto appliable2 = [](Config & c, Action::BasicAction &) - { - return !c.stackEmpty(); - }; - Action::BasicAction basicAction2 = - {Action::BasicAction::Type::Write, "", apply2, undo2, appliable2}; - - sequence.emplace_back(basicAction2); - - auto apply4 = [b2](Config & c, Action::BasicAction & ba) + // Empty the stack. + auto apply4 = [](Config & c, Action::BasicAction & ba) { ba.data = ""; for (int i = c.stackSize()-1; i >= 0; i--) @@ -696,6 +649,23 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na {Action::BasicAction::Type::Pop, "", apply4, undo4, appliable4}; sequence.emplace_back(basicAction4); + + // Update the IDs of the words in the new sentence + auto apply5 = [](Config & c, Action::BasicAction &) + { + c.updateIdsInSequence(); + }; + auto undo5 = [](Config &, Action::BasicAction &) + { + }; + auto appliable5 = [](Config &, Action::BasicAction &) + { + return true; + }; + Action::BasicAction basicAction5 = + {Action::BasicAction::Type::Write, "", apply5, undo5, appliable5}; + + sequence.emplace_back(basicAction5); } else if(std::string(b1) == "BACK") { @@ -833,7 +803,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, int index = config.getHead() + relativeIndex; - return !(index < 0) && index < tape.size(); + return !(index < 0) && index < tape.size() && tape.getHyp(relativeIndex).empty(); } bool ActionBank::isRuleAppliable(Config & config, diff --git a/transition_machine/src/BD.cpp b/transition_machine/src/BD.cpp index b9eb9c55018c3262bddd4a67c5c2e7e5b8bfca92..42f337fe22f5319eb1a8c5dc5b393b6ba5539bad 100644 --- a/transition_machine/src/BD.cpp +++ b/transition_machine/src/BD.cpp @@ -2,9 +2,10 @@ #include "File.hpp" #include "util.hpp" -BD::Line::Line(int num, std::string name, std::string dictName, +BD::Line::Line(int num, int outputIndex, std::string name, std::string dictName, int inputColumn, bool mustPrint, bool isKnown) { + this->outputIndex = outputIndex; this->dict = dictName; this->num = num; this->name = name; @@ -52,7 +53,9 @@ BD::BD(const std::string & BDfilename, const std::string & MCDfilename) if(buffer[0] == '#') continue; - if(sscanf(buffer, "%s %s %s %s %d", name, refHyp, dict, policy, &mustPrint) != 5) + int outputIndex; + + if(sscanf(buffer, "%d %s %s %s %s %d", &outputIndex, name, refHyp, dict, policy, &mustPrint) != 6) { fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid BD line. Aborting.\n", ERRINFO, buffer); exit(1); @@ -68,7 +71,7 @@ BD::BD(const std::string & BDfilename, const std::string & MCDfilename) int inputColumn = mcdStr2Col.find(name) == mcdStr2Col.end() ? -1 : mcdStr2Col[name]; - lines.emplace_back(new Line(lines.size(), name, dict, inputColumn, mustPrint == 1, known)); + lines.emplace_back(new Line(lines.size(), outputIndex, name, dict, inputColumn, mustPrint == 1, known)); Line * line = lines.back().get(); num2line.emplace(line->num, line); name2line.emplace(line->name, line); @@ -180,3 +183,8 @@ bool BD::lineIsKnown(int line) return lines[line]->isKnown; } +int BD::getOutputIndexOfLine(int index) +{ + return lines[index]->outputIndex; +} + diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 25b16a69aa978d7b06988079b0a2d23b2f54ece3..413aac277cfca6e72e217a2622da9e28bc0d39ce 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -80,6 +80,13 @@ Classifier::Classifier(const std::string & filename, bool trainMode) int batchsizeRead = 0; if(fscanf(fd, "Batchsize : %d\n", &batchsizeRead) == 1) batchSize = batchsizeRead; + + dynamic = false; + if(fscanf(fd, "Dynamic : %[^\n]\n", buffer) == 1) + { + if (!strcmp(buffer, "yes")) + dynamic = true; + } } Classifier::Type Classifier::str2type(const std::string & s) @@ -278,6 +285,14 @@ std::vector<std::string> Classifier::getZeroCostActions(Config & config) if (a.appliable(config) && oracle->getActionCost(config, a.name) == 0) result.emplace_back(a.name); + if (ProgramParameters::debug) + { + fprintf(stderr, "Zero cost actions : "); + for (auto & s : result) + fprintf(stderr, "<%s>", s.c_str()); + fprintf(stderr, "\n"); + } + if (result.empty() && as->hasDefaultAction) if (as->getDefaultAction()->appliable(config)) result.emplace_back(as->getDefaultAction()->name); @@ -421,3 +436,8 @@ void Classifier::endOfIteration() nn->endOfIteration(); } +bool Classifier::isDynamic() +{ + return dynamic; +} + diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index a86b5db373ef3016104f6c69f54910678cfe1e6f..5bcfc8daddbdf5774788520c33db186ede7106fa 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -128,9 +128,12 @@ void Config::fillTapesWithInput() rawInput = ""; std::vector<std::string> cols; unsigned int usualColsSize = 0; + auto & ids = getTape("ID"); + auto & govs = getTape("GOV"); for (auto & sentence : inputContent) { + int sentenceStartIndex = ids.refSize(); for (unsigned int wordIndex = 0; wordIndex < sentence.size(); wordIndex++) { auto & word = sentence[wordIndex]; @@ -176,9 +179,46 @@ void Config::fillTapesWithInput() exit(1); } } - getTape(ProgramParameters::sequenceDelimiterTape).addToRef(wordIndex == sentence.size()-1 ? ProgramParameters::sequenceDelimiter : "_"); + getTape(ProgramParameters::sequenceDelimiterTape).addToRef(wordIndex == sentence.size()-1 ? ProgramParameters::sequenceDelimiter : ""); getTape(ProgramParameters::sequenceDelimiterTape).addToHyp(""); } + + for (int word = sentenceStartIndex; word < ids.refSize(); word++) + { + if (split(ids.getRef(word), '-').size() > 1) + continue; + if (govs.getRef(word) == "0") + continue; + + try + { + int id = std::stoi(ids.getRef(word)); + std::string goalId = govs.getRef(word); + int relativeIndex = 0; + + if (std::stoi(goalId) < id) + { + while (ids.getRef(word+relativeIndex) != goalId) + { + if (--relativeIndex+word < 0) + throw ""; + } + } + else + { + while (ids.getRef(word+relativeIndex) != goalId) + if (++relativeIndex+word >= ids.refSize()) + throw ""; + } + + govs.setRef(word, std::to_string(relativeIndex)); + } + catch (std::exception &) + { + fprintf(stderr, "ERROR (%s) : invalid governor '%s' '%s'. Aborting.\n", ERRINFO, govs.getRef(word).c_str(), getTape("FORM").getRef(word).c_str()); + exit(1); + } + } } // Making all tapes the same size @@ -271,15 +311,32 @@ void Config::printAsOutput(FILE * output, int dataIndex, int realIndex, bool for std::vector< std::pair<std::string, float> > toPrint; for (unsigned int j = 0; j < tapes.size(); j++) { + int outputTapeIndex = bd.getOutputIndexOfLine(j); + + while ((int)toPrint.size() < outputTapeIndex+1) + toPrint.emplace_back("_", 0.0); + if(bd.mustPrintLine(j)) { if (!forceRef) - toPrint.emplace_back(tapes[j][dataIndex-head].empty() ? "_" : tapes[j][dataIndex-head].c_str(), tapes[j].getEntropy(dataIndex-head)); + toPrint[outputTapeIndex] = {tapes[j][dataIndex-head].empty() ? "_" : tapes[j][dataIndex-head].c_str(), tapes[j].getEntropy(dataIndex-head)}; else - toPrint.emplace_back(tapes[j].getRef(dataIndex-head).empty() ? "_" : tapes[j].getRef(dataIndex-head).c_str(), tapes[j].getEntropy(dataIndex-head)); + toPrint[outputTapeIndex] = {tapes[j].getRef(dataIndex-head).empty() ? "_" : tapes[j].getRef(dataIndex-head).c_str(), tapes[j].getEntropy(dataIndex-head)}; } } + bool allEmpty = true; + + for (auto & it : toPrint) + if (it.first != "_" && !it.first.empty()) + { + allEmpty = false; + break; + } + + if (allEmpty) + return; + ProgramOutput::instance.addLine(output, toPrint, realIndex); if (!ProgramParameters::delayedOutput) @@ -297,8 +354,14 @@ void Config::moveHead(int mvt) { head += mvt; - if (hasTape("ID") && split(getTape("ID").getHyp(0), '-').size() <= 1) - currentWordIndex += mvt; + if (mvt > 0) + for (int i = 0; i < mvt; i++) + if (hasTape("ID") && split(getTape("ID").getHyp(i), '-').size() <= 1) + currentWordIndex += 1; + if (mvt < 0) + for (int i = 0; i < mvt; i++) + if (hasTape("ID") && split(getTape("ID").getHyp(-i), '-').size() <= 1) + currentWordIndex -= 1; for (auto & tape : tapes) tape.moveHead(mvt); @@ -391,6 +454,19 @@ void Config::Tape::setHyp(int relativeIndex, const std::string & elem) hyp.set(head + relativeIndex, std::pair<std::string,float>(elem,totalEntropy)); } +void Config::Tape::setRef(int relativeIndex, const std::string & elem) +{ + ref.set(head + relativeIndex, elem); +} + +void Config::Tape::set(int relativeIndex, const std::string & elem) +{ + if(isKnown) + return setRef(relativeIndex, elem); + + return setHyp(relativeIndex, elem); +} + std::string & Config::getCurrentStateName() { if(currentStateName.empty()) @@ -433,9 +509,7 @@ LimitedStack<float> & Config::getCurrentStateEntropyHistory() void Config::shuffle() { - reset(); std::random_shuffle(inputContent.begin(), inputContent.end()); - fillTapesWithInput(); } int Config::stackGetElem(int index) const @@ -558,7 +632,7 @@ bool Config::endOfTapes() const bool Config::Tape::headIsAtEnd() const { - return head == ref.getLastIndex(); + return head >= ref.getLastIndex(); } int Config::Tape::size() @@ -631,11 +705,14 @@ void Config::printTheRest(bool forceRef) if (!outputFile) return; + updateIdsInSequence(); + setGovsAsUD(forceRef); + int tapeSize = tapes[0].size(); int goalPrintIndex = lastIndexPrinted; - int realIndex = tapeSize - 1 - ((((tapes[0].dataSize()-(goalPrintIndex == -1 ? 0 : 0)))-(goalPrintIndex+1))+(goalPrintIndex)); - for (int i = goalPrintIndex+1; i < (tapes[0].dataSize()-(goalPrintIndex == -1 ? 1 : 0)); i++) + int realIndex = tapeSize - ((((tapes[0].dataSize()-(goalPrintIndex == -1 ? 0 : 0)))-(goalPrintIndex+1))+(goalPrintIndex)); + for (int i = goalPrintIndex+1; i < tapes[0].dataSize(); i++) { printAsOutput(outputFile, i, realIndex, forceRef); realIndex++; @@ -738,3 +815,68 @@ void Config::setLastIndexPrinted(int lastIndexPrinted) this->lastIndexPrinted = lastIndexPrinted; } +void Config::setGovsAsUD(bool ref) +{ + auto & ids = getTape("ID"); + auto & govs = getTape("GOV"); + + if (ref) + for (int i = 0; i < ids.refSize(); i++) + { + try + { + int relativeIndex = std::stoi(govs.getRef(i-head)); + if (relativeIndex == 0) + continue; + auto idOfGov = ids.getRef(i+relativeIndex-head); + govs.setRef(i-head, idOfGov); + } + catch (std::exception &) {continue;} + } + else + for (int i = 0; i < ids.hypSize(); i++) + { + try + { + int relativeIndex = std::stoi(govs.getHyp(i-head)); + if (relativeIndex == 0) + continue; + auto idOfGov = ids.getHyp(i+relativeIndex-head); + govs.setHyp(i-head, idOfGov); + } + catch (std::exception &) {continue;} + } +} + +void Config::updateIdsInSequence() +{ + int sentenceEnd = getHead(); + auto & eos = getTape(ProgramParameters::sequenceDelimiterTape); + auto & ids = getTape("ID"); + while (sentenceEnd >= 0 && eos[sentenceEnd-getHead()] != ProgramParameters::sequenceDelimiter) + sentenceEnd--; + int sentenceStart = std::max(0,sentenceEnd-1); + while (sentenceStart >= 0 && eos[sentenceStart-getHead()] != ProgramParameters::sequenceDelimiter) + sentenceStart--; + sentenceStart++; + + if (sentenceEnd < 0) + { + sentenceStart = 0; + sentenceEnd = eos.hypSize()-1; + } + + int curId = 1; + for (int i = sentenceStart; i <= sentenceEnd; i++) + { + auto splited = split(ids.getRef(i-getHead()), '-'); + if (splited.size() == 1) + { + ids.setHyp(i-getHead(), std::to_string(curId++)); + continue; + } + int multiWordSize = splited.size(); + ids.setHyp(i-getHead(), std::to_string(curId)+"-"+std::to_string(curId+multiWordSize-1)); + } +} + diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 8cd69d0fa57f741a7592d41b52e44bc2bcbda402..2242fc190e0efd7987b5f41d3105d6f7e4388377 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -387,6 +387,12 @@ void Oracle::createDatabase() movement = 1; newState = "signature"; } + + if (movement > 0 && c.endOfTapes()) + movement = 0; + + if (split(previousAction, ' ')[0] == "eos" && c.endOfTapes()) + return std::string(""); } else if (previousState == "error_parser") { @@ -475,7 +481,7 @@ void Oracle::createDatabase() while (start+c.getHead() < c.getTape("SGN").size() && !c.getTape("SGN").getHyp(start).empty()) start++; - while (end >= 0 && c.getTape("FORM").getHyp(end).empty()) + while (end >= 0 && c.getTape("FORM")[end].empty()) end--; if (start > end) @@ -584,26 +590,25 @@ void Oracle::createDatabase() }, [](Config & c, Oracle *, const std::string & action) { - bool hasId = c.hasTape("ID"); - + auto & ids = c.getTape("ID"); auto & labels = c.getTape("LABEL"); auto & govs = c.getTape("GOV"); auto & eos = c.getTape(ProgramParameters::sequenceDelimiterTape); int head = c.getHead(); - int stackHead = c.stackEmpty() ? 0 : c.stackTop(); - int stackGov = 0; - bool stackNoGov = false; - try - { - stackGov = stackHead + std::stoi(govs.getRef(stackHead-head)); - } catch (std::exception &){stackNoGov = true;} - int headGov = 0; - bool headNoGov = false; + bool headIsMultiword = split(ids.getRef(0), '-').size() > 1; + int headGov = -1; try {headGov = head + std::stoi(govs.getRef(0));} - catch (std::exception &){headNoGov = true;} - int sentenceStart = c.getHead()-1 < 0 ? 0 : c.getHead()-1; - int sentenceEnd = c.getHead(); + catch (std::exception &) {headGov = -1;} + + int stackHead = c.stackEmpty() ? 0 : c.stackTop(); + bool stackHeadIsMultiword = split(ids.getRef(stackHead-head), '-').size() > 1; + int stackGov = -1; + try {stackGov = stackHead + std::stoi(govs.getRef(stackHead-head));} + catch (std::exception &) {stackGov = -1;} + + int sentenceStart = head-1 < 0 ? 0 : head-1; + int sentenceEnd = head; int cost = 0; @@ -620,29 +625,19 @@ void Oracle::createDatabase() if (parts[0] == "SHIFT") { - if (headNoGov) + if (headIsMultiword) return 0; - for (int i = sentenceStart; i <= sentenceEnd; i++) + for (int j = 0; j < c.stackSize(); j++) { - if (!isNum(govs.getRef(i-head))) - { - continue; - fprintf(stderr, "ERROR (%s) : govs.ref[%d] = <%s>. Aborting.\n", ERRINFO, i, govs.getRef(i-head).c_str()); - exit(1); - } - - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){fprintf(stderr, "ERROR (%s) : aborting.\n", ERRINFO); exit(1);} - - for (int j = 0; j < c.stackSize(); j++) + auto s = c.stackGetElem(j); + try { - auto s = c.stackGetElem(j); - if (s == i) - if (otherGov == head || headGov == s) - cost++; + int sGov = s + std::stoi(govs.getRef(s-head)); + if (sGov == head || headGov == s) + cost++; } + catch (std::exception &) {continue;} } if (c.stackSize() && stackHead == head) @@ -673,18 +668,18 @@ void Oracle::createDatabase() } else if (parts[0] == "REDUCE") { - if (stackNoGov) + if (stackHeadIsMultiword) return 0; - if (stackGov == stackHead) - cost++; for (int i = head; i <= sentenceEnd; i++) { - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){continue;} - if (otherGov == stackHead) - cost++; + try + { + int otherGov = i + std::stoi(govs.getRef(i-head)); + if (otherGov == stackHead || stackGov == i) + cost++; + } + catch (std::exception &) {continue;} } if (eos.getRef(stackHead-head) != ProgramParameters::sequenceDelimiter) @@ -694,22 +689,21 @@ void Oracle::createDatabase() } else if (parts[0] == "LEFT") { - if (stackNoGov) - return 0; - - if (stackGov == stackHead) - cost++; + if (stackHeadIsMultiword || headIsMultiword) + return 1; if (eos.getRef(stackHead-head) == ProgramParameters::sequenceDelimiter) cost++; for (int i = head+1; i <= sentenceEnd; i++) { - int otherGov = 0; - try {otherGov = i + std::stoi(govs.getRef(i-head));} - catch (std::exception &){continue;} - if (otherGov == stackHead || stackGov == i) - cost++; + try + { + int otherGov = i + std::stoi(govs.getRef(i-head)); + if (otherGov == stackHead || stackGov == i) + cost++; + } + catch (std::exception &) {continue;} } if (stackGov != head) @@ -725,8 +719,8 @@ void Oracle::createDatabase() } else if (parts[0] == "RIGHT") { - if (headNoGov) - return 0; + if (stackHeadIsMultiword || headIsMultiword) + return 1; for (int j = 0; j < c.stackSize(); j++) { @@ -735,17 +729,22 @@ void Oracle::createDatabase() if (s == c.stackTop()) continue; - int otherGov = 0; - try {otherGov = s + std::stoi(govs.getRef(s-head));} - catch (std::exception &){continue;} - if (otherGov == head || headGov == s) - cost++; + try + { + int otherGov = s + std::stoi(govs.getRef(s-head)); + if (otherGov == head || headGov == s) + cost++; + } + catch (std::exception &) {continue;} } for (int i = head; i <= sentenceEnd; i++) if (headGov == i) cost++; + if (headGov != stackHead) + cost++; + if (parts.size() == 1) return cost; if (labels.getRef(0) == parts[1])