diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp index 974e0c52963bcb49b7bba28c583d385e50c669fe..5c419724bd908267a28cb7a684757ea368cc96f6 100644 --- a/MLP/include/MLP.hpp +++ b/MLP/include/MLP.hpp @@ -192,19 +192,11 @@ class MLP /// @return A vector containing one score per possible class. std::vector<float> predict(FeatureModel::FeatureDescription & fd); - /// @brief Train the MLP on these TrainingExamples. + /// @brief Update the parameters according to the given gold class. /// - /// @param examples A batch of training examples. - /// - /// @return The number of these training examples correctly classified by the MLP. - int trainOnBatch(TrainingExamples & examples); - - /// @brief Predict the class of training examples. - /// - /// @param examples The training examples. - /// - /// @return The number of these training examples correctly classified by the MLP. - int getScoreOnBatch(TrainingExamples & examples); + /// @param fd The input to use. + /// @param gold The gold class of this input. + void update(FeatureModel::FeatureDescription & fd, int gold); /// @brief Save the MLP to a file. 
/// diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp index 6fe13b37278781222ec31baa0449102c4e12b4a3..c7c3c9891827a513554778346f59a18c9b0b627e 100644 --- a/MLP/src/MLP.cpp +++ b/MLP/src/MLP.cpp @@ -174,6 +174,23 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd) return as_vector(cg.forward(output)); } +void MLP::update(FeatureModel::FeatureDescription & fd, int gold) +{ + dynet::ComputationGraph cg; + + std::vector<dynet::Expression> expressions; + + for (auto & featValue : fd.values) + expressions.emplace_back(featValue2Expression(cg, featValue)); + + dynet::Expression input = dynet::concatenate(expressions); + dynet::Expression output = run(cg, input); + dynet::Expression loss = pickneglogsoftmax(output, gold); + + cg.backward(loss); + trainer->update(); +} + dynet::DynetParams & MLP::getDefaultParams() { static dynet::DynetParams params; @@ -300,117 +317,6 @@ void MLP::printParameters(FILE * output) fprintf(output, "Parameters : NOT IMPLEMENTED\n"); } -int MLP::trainOnBatch(TrainingExamples & examples) -{ - dynet::ComputationGraph cg; - std::vector<dynet::Expression> inputs; - std::vector<unsigned int> goldClasses; - int inputDim = 0; - int outputDim = layers.back().output_dim; - - for(unsigned int i = 0; i < examples.size(); i++) - { - int index = examples.order[i]; - auto & example = examples.examples[index]; - - std::vector<dynet::Expression> expressions; - expressions.clear(); - - for (auto & featValue : example.values) - expressions.emplace_back(featValue2Expression(cg, featValue)); - - inputs.emplace_back(dynet::concatenate(expressions)); - inputDim = inputs.back().dim().rows(); - goldClasses.emplace_back((unsigned)examples.classes[index]); - } - - dynet::Expression concatenation = dynet::concatenate(inputs); - int batchSize = examples.size(); - - dynet::Expression batchedInput = reshape((concatenation), - dynet::Dim({(unsigned)inputDim}, batchSize)); - - dynet::Expression output = run(cg, batchedInput); - - if(trainMode) - { - 
dynet::Expression batchedLoss = pickneglogsoftmax(output, goldClasses); - dynet::Expression loss = sum_batches(batchedLoss); - cg.backward(loss); - trainer->update(); - } - - int nbCorrect = 0; - std::vector<float> predictions = as_vector(output.value()); - for (unsigned int i = 0; (int)i < batchSize; i++) - { - int prediction = 0; - - for (unsigned int j = 0; (int)j < outputDim; j++) - if(predictions[i*outputDim+j] > predictions[i*outputDim+prediction]) - prediction = (int)j; - - if(prediction == (int)goldClasses[i]) - nbCorrect++; - } - - return nbCorrect; -} - -int MLP::getScoreOnBatch(TrainingExamples & examples) -{ - bool currentDropoutActive = dropoutActive; - dropoutActive = false; - - dynet::ComputationGraph cg; - std::vector<dynet::Expression> inputs; - std::vector<unsigned int> goldClasses; - int inputDim = 0; - int outputDim = layers.back().output_dim; - - for(unsigned int i = 0; i < examples.size(); i++) - { - int index = examples.order[i]; - auto & example = examples.examples[index]; - - std::vector<dynet::Expression> expressions; - expressions.clear(); - - for (auto & featValue : example.values) - expressions.emplace_back(featValue2Expression(cg, featValue)); - - inputs.emplace_back(dynet::concatenate(expressions)); - inputDim = inputs.back().dim().rows(); - goldClasses.emplace_back((unsigned)examples.classes[index]); - } - - dynet::Expression concatenation = dynet::concatenate(inputs); - int batchSize = examples.size(); - - dynet::Expression batchedInput = reshape((concatenation), - dynet::Dim({(unsigned)inputDim}, batchSize)); - - dynet::Expression output = run(cg, batchedInput); - - int nbCorrect = 0; - std::vector<float> predictions = as_vector(output.value()); - for (unsigned int i = 0; (int)i < batchSize; i++) - { - int prediction = 0; - - for (unsigned int j = 0; (int)j < outputDim; j++) - if(predictions[i*outputDim+j] > predictions[i*outputDim+prediction]) - prediction = (int)j; - - if(prediction == (int)goldClasses[i]) - nbCorrect++; - } - - 
dropoutActive = currentDropoutActive; - - return nbCorrect; -} - void MLP::save(const std::string & filename) { saveStruct(filename); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index e31b82be60361858c814bfc801b878840c6dfef7..0f66aeb3b8b6025d01a428505ddabf9ddb453fd0 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -43,6 +43,24 @@ class Trainer /// @brief If true, duplicates examples will be removed from the training set. bool removeDuplicates; + /// @brief For each classifier, a pair of number examples seen / number examples successfully classified + std::map< std::string, std::pair<int, int> > trainCounter; + + /// @brief For each classifier, the train score for the current iteration. + std::map< std::string, float > scores; + + /// @brief For each classifier, the best score seen on dev. + std::map< std::string, float > bestScores; + + /// @brief Whether or not each Classifier topology has been printed. + std::map< std::string, bool > topologyPrinted; + + /// @brief Current iteration. + int curIter; + + /// @brief Number of iterations. + int nbIter; + public : /// @brief The FeatureDescritpion of a Config. @@ -50,18 +68,6 @@ class Trainer private : - /// @brief Train the TransitionMachine using batches of examples. - /// - /// For each epoch all the Classifier of the TransitionMachine are fed all the - /// training examples, at the end of the epoch Classifier are evaluated on - /// the devBD if available, and each Classifier will be saved only if its score - /// on the current epoch is its all time best.\n - /// When a Classifier is saved that way, all the Dict involved are also saved. - /// @param nbIter The number of epochs. - /// @param batchSize The size of each batch (in number of examples). - /// @param mustShuffle Will the examples be shuffled after every epoch ? - void trainBatched(int nbIter, int batchSize, bool mustShuffle); - /// @brief Train the TransitionMachine one example at a time. 
/// /// For each epoch all the Classifier of the TransitionMachine are fed all the @@ -70,49 +76,16 @@ class Trainer /// on the current epoch is its all time best.\n /// When a Classifier is saved that way, all the Dict involved are also saved. /// @param nbIter The number of epochs. - void trainUnbatched(int nbIter); + /// @param mustShuffle Whether or not to shuffle examples between each epoch. + void trainUnbatched(int nbIter, bool mustShuffle); - /// @brief Uses a TM and a config to create the TrainingExamples that will be used during training. - /// - /// @param config The config to use. - /// - /// @return For each classifier, a set of training examples. - std::map<Classifier*,TrainingExamples> getExamplesByClassifier(Config & config); + /// @brief Compute and print scores for each Classifier on this epoch, and save the Classifier if they achieved their all time best score. + void printScoresAndSave(FILE * output); - /// @brief Make each Classifier go over every examples. - /// - /// Depending on getScoreOnBatch, it can update the parameters or not. - /// @param examples Map each trainable Classifier with a set of examples. - /// @param batchSize The batch size to use. - /// @param nbExamples Map each trainable Classifier to a count of how many examples it has seen during this epoch and a count of how many of this examples it has correctly classified. This map is filled by this function. - /// @param getScoreOnBatch The MLP function that must be called to get the score of a classifier on a certain batch. -void processAllExamples( - std::map<Classifier*, TrainingExamples> & examples, - int batchSize, std::map< std::string, std::pair<int, int> > & nbExamples, - std::function<int(Classifier *, TrainingExamples &)> getScoreOnBatch); - - /// @brief Print the score obtained by all Classifier on this epoch. - /// - /// @param output Where to print the output. 
- /// @param nbExamplesTrain Map each trainable Classifier to a count of how many train examples it has seen during this epoch and a count of how many of this examples it has correctly classified. - /// @param nbExamplesDev Map each trainable Classifier to a count of how many dev examples it has seen during this epoch and a count of how many of this examples it has correctly classified. - /// @param trainScores The scores obtained by each Classifier on the train set. - /// @param devScores The scores obtained by each Classifier on the train set. - /// @param bestIter Map each classifier to its best epoch. It is updated by this function. - /// @param nbIter The total number of epoch of the training. - /// @param curIter The current epoch of the training. - void printIterationScores(FILE * output, - std::map< std::string, std::pair<int, int> > & nbExamplesTrain, - std::map< std::string, std::pair<int, int> > & nbExamplesDev, - std::map< std::string, std::vector<float> > & trainScores, - std::map< std::string, std::vector<float> > & devScores, - std::map<std::string, int> & bestIter, - int nbIter, int curIter); - - /// @brief For every Classifier, shuffle its training examples. + /// @brief Get the scores of the classifiers on the dev dataset. /// - /// @param examples Map each Classifier to a set of training examples. - void shuffleAllExamples(std::map<Classifier*,TrainingExamples> & examples); + /// @return Map from each Classifier name to their score. 
+ std::map<std::string, float> getScoreOnDev(); public : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 56b24ce70b60bd08ae6d0837673f1ae56913ef88..0b6f5a399eb90e457bd11be915260c05528b9bb1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -16,221 +16,92 @@ Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, C this->removeDuplicates = removeDuplicates; } -std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config) +std::map<std::string, float> Trainer::getScoreOnDev() { - std::map<Classifier*, TrainingExamples> examples; + if (!devConfig) + return {}; - while (!config.isFinal()) + tm.reset(); + devConfig->reset(); + + std::map< std::string, std::pair<int, int> > counts; + + while (!devConfig->isFinal()) { TransitionMachine::State * currentState = tm.getCurrentState(); Classifier * classifier = currentState->classifier; - config.setCurrentStateName(&currentState->name); + devConfig->setCurrentStateName(&currentState->name); Dict::currentClassifierName = classifier->name; - classifier->initClassifier(config); + classifier->initClassifier(*devConfig); - if (debugMode) + if(!classifier->needsTrain()) { - config.printForDebug(stderr); - fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); - } - - int neededActionIndex = classifier->getOracleActionIndex(config); - std::string neededActionName = classifier->getActionName(neededActionIndex); + int neededActionIndex = classifier->getOracleActionIndex(*devConfig); + std::string neededActionName = classifier->getActionName(neededActionIndex); + Action * action = classifier->getAction(neededActionName); - if (debugMode) - { - fprintf(stderr, "Action : %s\n", neededActionName.c_str()); - fprintf(stderr, "\n"); + action->apply(*devConfig); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + tm.takeTransition(transition); + devConfig->moveHead(transition->headMvt); } - - if(classifier->needsTrain()) - 
examples[classifier].add(classifier->getFeatureDescription(config), neededActionIndex); - - Action * action = classifier->getAction(neededActionName); - if(!action->appliable(config)) - fprintf(stderr, "WARNING (%s) : action \'%s\' is not appliable.\n", ERRINFO, neededActionName.c_str()); - action->apply(config); - - TransitionMachine::Transition * transition = tm.getTransition(neededActionName); - tm.takeTransition(transition); - config.moveHead(transition->headMvt); - } - - if (removeDuplicates) - for (auto & it : examples) - it.second.removeDuplicates(); - - return examples; -} - -void Trainer::processAllExamples( - std::map<Classifier*, TrainingExamples> & examples, - int batchSize, std::map< std::string, std::pair<int, int> > & nbExamples, - std::function<int(Classifier *, TrainingExamples &)> getScoreOnBatch) -{ - for(auto & it : examples) - { - while(true) + else { - TrainingExamples batch = it.second.getBatch(batchSize); - - if (batch.size() == 0) - break; + auto weightedActions = classifier->weightActions(*devConfig); + std::string pAction = ""; - int nbCorrects = getScoreOnBatch(it.first, batch); + for (auto & it : weightedActions) + if (it.first) + { + pAction = it.second.second; + break; + } - nbExamples[it.first->name].first += batch.size(); - nbExamples[it.first->name].second += nbCorrects; - } + auto zeroCostActions = classifier->getZeroCostActions(*devConfig); - it.second.reset(); - } -} - -void Trainer::printIterationScores(FILE * output, - std::map< std::string, std::pair<int, int> > & nbExamplesTrain, - std::map< std::string, std::pair<int, int> > & nbExamplesDev, - std::map< std::string, std::vector<float> > & trainScores, - std::map< std::string, std::vector<float> > & devScores, - std::map<std::string, int> & bestIter, - int nbIter, int curIter) -{ - std::vector<std::string> names; - std::vector<std::string> acc; - std::vector<std::string> train; - std::vector<std::string> dev; - std::vector<std::string> saved; + bool pActionIsZeroCost = 
false; + for (auto & s : zeroCostActions) + if (s == pAction) + { + pActionIsZeroCost = true; + break; + } - fprintf(output, "Iteration %d/%d :\n", curIter+1, nbIter); - for(auto & it : nbExamplesTrain) - { - float scoreTrain = 100.0*it.second.second / it.second.first; - float scoreDev = devConfig ? 100.0*nbExamplesDev[it.first].second / nbExamplesDev[it.first].first : -1.0; + counts[classifier->name].first++; + counts[classifier->name].second += pActionIsZeroCost ? 1 : 0; - trainScores[it.first].emplace_back(scoreTrain); - devScores[it.first].emplace_back(scoreDev); + std::string actionName = pAction; + Action * action = classifier->getAction(actionName); - bool isBest = curIter ? false : true; - if (devConfig) - { - if (scoreDev > devScores[it.first][bestIter[it.first]]) - { - isBest = true; - bestIter[it.first] = devScores[it.first].size()-1; - } - } - else - { - if (scoreTrain > trainScores[it.first][bestIter[it.first]]) - { - isBest = true; - bestIter[it.first] = trainScores[it.first].size()-1; - } + action->apply(*devConfig); + TransitionMachine::Transition * transition = tm.getTransition(actionName); + tm.takeTransition(transition); + devConfig->moveHead(transition->headMvt); } - - names.emplace_back(it.first); - acc.emplace_back("accuracy"); - train.emplace_back(": train(" + float2str(scoreTrain, "%.2f") + "%)"); - dev.emplace_back(devConfig ? "dev(" + float2str(scoreDev, "%.2f") + "%)" : ""); - saved.emplace_back(isBest ? 
"SAVED" : ""); } - printColumns(output, {names, acc, train, dev, saved}); -} + std::map<std::string, float> scores; + for (auto & it : counts) + scores[it.first] = 100.0 * it.second.second / it.second.first; -void Trainer::shuffleAllExamples(std::map<Classifier*, TrainingExamples> & examples) -{ - for (auto & it : examples) - it.second.shuffle(); + return scores; } -void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) +void Trainer::trainUnbatched(int nbIter, bool mustShuffle) { - std::map<Classifier*,TrainingExamples> trainExamples; - std::map<Classifier*,TrainingExamples> devExamples; - - fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str()); - - trainExamples = getExamplesByClassifier(trainConfig); - - tm.reset(); - - if(devBD && devConfig) - devExamples = getExamplesByClassifier(*devConfig); - - auto & classifiers = tm.getClassifiers(); - for(Classifier * cla : classifiers) - if(cla->needsTrain()) - cla->printTopology(stderr); - - std::map< std::string, std::vector<float> > trainScores; - std::map< std::string, std::vector<float> > devScores; - std::map<std::string, int> bestIter; + this->nbIter = nbIter; Dict::saveDicts(expPath, ""); - for (int i = 0; i < nbIter; i++) - { - std::map< std::string, std::pair<int, int> > nbExamplesTrain; - std::map< std::string, std::pair<int, int> > nbExamplesDev; - - if(mustShuffle) - shuffleAllExamples(trainExamples); - - processAllExamples(trainExamples, batchSize, nbExamplesTrain, - [](Classifier * c, TrainingExamples & ex) - { - return c->trainOnBatch(ex); - }); - - processAllExamples(devExamples, batchSize, nbExamplesDev, - [](Classifier * c, TrainingExamples & ex) - { - return c->getScoreOnBatch(ex); - }); - - printIterationScores(stderr, nbExamplesTrain, nbExamplesDev, - trainScores, devScores, bestIter, nbIter, i); - - for(Classifier * cla : classifiers) - if(cla->needsTrain()) - if(bestIter[cla->name] == i) - { - cla->save(expPath + cla->name + ".model"); - Dict::saveDicts(expPath, cla->name); - 
} - } -} - -void Trainer::trainUnbatched(int nbIter) -{ - std::map<Classifier*,TrainingExamples> devExamples; - fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str()); - if(devBD && devConfig) - devExamples = getExamplesByClassifier(*devConfig); - - auto & classifiers = tm.getClassifiers(); - for(Classifier * cla : classifiers) - if(cla->needsTrain()) - cla->printTopology(stderr); - - std::map< std::string, std::vector<float> > trainScores; - std::map< std::string, std::vector<float> > devScores; - std::map<std::string, int> bestIter; - - Dict::saveDicts(expPath, ""); - - for (int i = 0; i < nbIter; i++) + for (curIter = 0; curIter < nbIter; curIter++) { tm.reset(); trainConfig.reset(); - std::map< std::string, std::pair<int, int> > nbExamplesTrain; - std::map< std::string, std::pair<int, int> > nbExamplesDev; - - int nbTreated = 0; + if(mustShuffle) + trainConfig.shuffle("EOS", "1"); while (!trainConfig.isFinal()) { @@ -240,134 +111,159 @@ void Trainer::trainUnbatched(int nbIter) Dict::currentClassifierName = classifier->name; classifier->initClassifier(trainConfig); - if (debugMode) + if(!classifier->needsTrain()) { - trainConfig.printForDebug(stderr); - fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); + int neededActionIndex = classifier->getOracleActionIndex(trainConfig); + std::string neededActionName = classifier->getActionName(neededActionIndex); + Action * action = classifier->getAction(neededActionName); + + action->apply(trainConfig); + TransitionMachine::Transition * transition = tm.getTransition(neededActionName); + tm.takeTransition(transition); + trainConfig.moveHead(transition->headMvt); } - - int neededActionIndex = classifier->getOracleActionIndex(trainConfig); - std::string neededActionName = classifier->getActionName(neededActionIndex); - - if (debugMode) + else { - fprintf(stderr, "Action : %s\n", neededActionName.c_str()); - fprintf(stderr, "\n"); - } + if (!topologyPrinted.count(classifier->name)) + { + 
topologyPrinted[classifier->name] = true; + classifier->printTopology(stderr); + } - if(classifier->needsTrain()) - { - TrainingExamples example; - example.add(classifier->getFeatureDescription(trainConfig), neededActionIndex); - int score = classifier->trainOnBatch(example); - nbExamplesTrain[classifier->name].first++; - nbExamplesTrain[classifier->name].second += score; - } + auto weightedActions = classifier->weightActions(trainConfig); + std::string pAction = ""; + std::string oAction = ""; - auto weightedActions = classifier->weightActions(trainConfig); + auto zeroCostActions = classifier->getZeroCostActions(trainConfig); + bool pActionIsZeroCost = false; - if (debugMode) - { - Classifier::printWeightedActions(stderr, weightedActions); - fprintf(stderr, "\n"); - } + if (zeroCostActions.empty()) + { + fprintf(stderr, "ERROR (%s) : Unable to find any zero cost action. Aborting.\n", ERRINFO); + trainConfig.printForDebug(stderr); + exit(1); + } - std::string & predictedAction = weightedActions[0].second.second; - Action * action = classifier->getAction(predictedAction); + for (auto & it : weightedActions) + if (it.first) + { + if (pAction == "") + pAction = it.second.second; - for(unsigned int i = 0; i < weightedActions.size(); i++) - { - predictedAction = weightedActions[i].second.second; - action = classifier->getAction(predictedAction); + for (auto & z : zeroCostActions) + { + if (pAction == z) + pActionIsZeroCost = true; - if(weightedActions[i].first) - break; - } + if (oAction == "" && z == it.second.second) + oAction = it.second.second; + } - if(!action->appliable(trainConfig)) - { - fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. 
Aborting\n", ERRINFO, predictedAction.c_str()); - exit(1); - } + if (pAction != "" && oAction != "") + break; + } - if (nbTreated % 1000 == 0) - fprintf(stderr, "%d - %s\n", nbTreated, predictedAction.c_str()); + classifier->trainOnExample(trainConfig, classifier->getActionIndex(oAction)); - nbTreated++; + trainCounter[classifier->name].first++; + trainCounter[classifier->name].second += pActionIsZeroCost ? 1 : 0; - action->apply(trainConfig); + std::string actionName = ""; + if (pActionIsZeroCost) + actionName = pAction; + else + actionName = zeroCostActions[rand() % zeroCostActions.size()]; - TransitionMachine::Transition * transition = tm.getTransition(predictedAction); + if (debugMode) + { + trainConfig.printForDebug(stderr); + fprintf(stderr, "pAction=<%s> oAction=<%s> nb=%lu action=<%s>\n", pAction.c_str(), oAction.c_str(), zeroCostActions.size(), actionName.c_str()); + } - tm.takeTransition(transition); - trainConfig.moveHead(transition->headMvt); - } + Action * action = classifier->getAction(actionName); - devConfig->reset(); - tm.reset(); - while (!devConfig->isFinal()) - { - TransitionMachine::State * currentState = tm.getCurrentState(); - Classifier * classifier = currentState->classifier; - devConfig->setCurrentStateName(&currentState->name); - Dict::currentClassifierName = classifier->name; - classifier->initClassifier(*devConfig); + action->apply(trainConfig); + TransitionMachine::Transition * transition = tm.getTransition(actionName); + tm.takeTransition(transition); + trainConfig.moveHead(transition->headMvt); + } + } - int neededActionIndex = classifier->getOracleActionIndex(*devConfig); - std::string neededActionName = classifier->getActionName(neededActionIndex); + printScoresAndSave(stderr); + } +} - auto weightedActions = classifier->weightActions(*devConfig); +void Trainer::printScoresAndSave(FILE * output) +{ + for (auto & it : trainCounter) + scores[it.first] = 100.0 * it.second.second / it.second.first; - std::string & predictedAction = 
weightedActions[0].second.second; - Action * action = classifier->getAction(predictedAction); + std::vector<std::string> names; + std::vector<std::string> acc; + std::vector<std::string> train; + std::vector<std::string> dev; + std::vector<std::string> savedStr; - for(unsigned int i = 0; i < weightedActions.size(); i++) - { - predictedAction = weightedActions[i].second.second; - action = classifier->getAction(predictedAction); + std::map<std::string, bool> saved; - if(weightedActions[i].first) - break; - } + auto devScores = getScoreOnDev(); - if(!action->appliable(trainConfig)) + if (devConfig) + { + for (auto & it : devScores) + { + if (bestScores.count(it.first) == 0 || bestScores[it.first] < it.second) { - fprintf(stderr, "ERROR (%s) : action \'%s\' is not appliable. Aborting\n", ERRINFO, predictedAction.c_str()); - exit(1); + bestScores[it.first] = it.second; + saved[it.first] = true; } - - if(classifier->needsTrain()) + else + saved[it.first] = false; + } + } + else + { + for (auto & it : scores) + { + if (bestScores.count(it.first) == 0 || bestScores[it.first] < it.second) { - nbExamplesDev[classifier->name].first++; - nbExamplesDev[classifier->name].second += neededActionName == predictedAction ? 
1 : 0; + bestScores[it.first] = it.second; + saved[it.first] = true; } + else + saved[it.first] = false; + } + } - action->apply(*devConfig); - - TransitionMachine::Transition * transition = tm.getTransition(predictedAction); + auto classifiers = tm.getClassifiers(); + for (auto * cla : classifiers) + { + if (!saved.count(cla->name)) + continue; - tm.takeTransition(transition); - devConfig->moveHead(transition->headMvt); + if (saved[cla->name]) + { + cla->save(expPath + cla->name + ".model"); + Dict::saveDicts(expPath, cla->name); } + } - printIterationScores(stderr, nbExamplesTrain, nbExamplesDev, - trainScores, devScores, bestIter, nbIter, i); - - for(Classifier * cla : classifiers) - if(cla->needsTrain()) - if(bestIter[cla->name] == i) - { - cla->save(expPath + cla->name + ".model"); - Dict::saveDicts(expPath, cla->name); - } + for (auto & it : saved) + { + names.emplace_back(it.first); + acc.emplace_back("accuracy"); + train.emplace_back(": train(" + float2str(scores[it.first], "%.2f") + "%)"); + dev.emplace_back(devConfig ? "dev(" +float2str(devScores[it.first], "%.2f") + "%)" : ""); + savedStr.emplace_back(saved[it.first] ? "SAVED" : ""); } + + fprintf(output, "Iteration %d/%d :\n", curIter+1, nbIter); + + printColumns(output, {names, acc, train, dev, savedStr}); } void Trainer::train(int nbIter, int batchSize, bool mustShuffle, bool batched) { - if (batched) - trainBatched(nbIter, batchSize, mustShuffle); - else - trainUnbatched(nbIter); + trainUnbatched(nbIter, mustShuffle); } diff --git a/transition_machine/include/Classifier.hpp b/transition_machine/include/Classifier.hpp index 0cb94fa5aab951cb1b9252ea5c2d72a30abf0128..2b0319e27e44d448a3cc3ef6136038eb3ebeb73a 100644 --- a/transition_machine/include/Classifier.hpp +++ b/transition_machine/include/Classifier.hpp @@ -100,24 +100,23 @@ class Classifier /// /// @return The name of the correct Action to take. 
std::string getOracleAction(Config & config); + /// @brief Get all the actions that are zero cost given this Config. + /// + /// @param config The current Config. + /// + /// @return A vector of all the actions that are zero cost given this Config. + std::vector<std::string> getZeroCostActions(Config & config); /// @brief Use the Oracle on the config to get the correct Action to take. /// /// @param config The Config to work with. /// /// @return The index of the correct Action to take. int getOracleActionIndex(Config & config); - /// @brief Predict the classes of these training examples. - /// - /// @param examples A set of training examples. - /// - /// @return The number of these training examples whose class has correctly been predicted. - int getScoreOnBatch(TrainingExamples & examples); - /// @brief Train this classifier of these TrainingExamples. + /// @brief Train the classifier on a training example. /// - /// @param examples a batch of training examples. - /// - /// @return The number of these training examples whose class has correctly been predicted. - int trainOnBatch(TrainingExamples & examples); + /// @param config The Config to work with. + /// @param gold The gold class of the Config. + void trainOnExample(Config & config, int gold); /// @brief Get the name of an Action from its index. /// /// The index of an Action can be seen as the index of the corresponding output neuron in the underlying neural network. @@ -149,6 +148,12 @@ class Classifier /// /// @param output Where to print. void printTopology(FILE * output); + /// @brief Return the index of the Action. + /// + /// @param action The action. + /// + /// @return The index of the Action. 
+ int getActionIndex(const std::string & action); }; #endif diff --git a/transition_machine/include/Config.hpp b/transition_machine/include/Config.hpp index 7f0773be69924142d1e796234a22104da61bd39a..7913b743792dd2dbcd966fea45f6b46ba134fd8f 100644 --- a/transition_machine/include/Config.hpp +++ b/transition_machine/include/Config.hpp @@ -116,6 +116,14 @@ class Config /// /// @return The history of Action of the current state in the TransitionMachine. std::vector<std::string> & getCurrentStateHistory(); + + /// @brief Shuffle the segments of the Config. + /// + /// For instance if you call shuffle("EOS", "1");\n + /// Sentences will be preserved, but their order will be shuffled. + /// @param delimiterTape The tape containing the delimiters of segments. + /// @param delimiter The delimiters of segments. + void shuffle(const std::string & delimiterTape, const std::string & delimiter); }; #endif diff --git a/transition_machine/include/Oracle.hpp b/transition_machine/include/Oracle.hpp index a3353d70c12a374c5fb42f51598b878810e75111..1423f031f277efa207af10527aeb3a02d0d832d5 100644 --- a/transition_machine/include/Oracle.hpp +++ b/transition_machine/include/Oracle.hpp @@ -23,13 +23,17 @@ class Oracle /// @brief Construct a new Oracle. /// /// @param initialize The function that will be called at the start of the program, to initialize the Oracle. - /// @param findAction The function that will give the correct Action to take, given the current Config. + /// @param findAction The function that will return the optimal action to take given the Config, for classifiers that do not require training. + /// @param isZeroCost The function that will return true if the given action is optimal for the given Config. 
Oracle(std::function<void(Oracle *)> initialize, - std::function<std::string(Config &, Oracle *)> findAction); + std::function<std::string(Config &, Oracle *)> findAction, + std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost); private : - /// @brief The function that will give the correct Action to take, given the current Config. + /// @brief Return true if the given action is optimal for the given Config. + std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost; + /// @brief Return the optimal action to take, only for non-trainable Classifier. std::function<std::string(Config &, Oracle *)> findAction; /// @brief The function that will be called at the start of the program, to initialize the Oracle. std::function<void(Oracle *)> initialize; @@ -66,11 +70,18 @@ class Oracle public : - /// @brief Get the correct Action to take, given the current Config. + /// @brief Tests whether or not the action is optimal for the given Config. /// /// @param config The current Config. + /// @param action The action to test. /// - /// @return The name of the correct Action to take. + /// @return Whether or not the action is optimal for the given Config. + bool actionIsZeroCost(Config & config, const std::string & action); + /// @brief Get the optimal action given the current Config, only for non-trainable Classifier. + /// + /// @param config The current Config. + /// + /// @return The optimal action to take given the current Config. std::string getAction(Config & config); /// @brief Initialize this oracle, by calling initialize. 
void init(); diff --git a/transition_machine/src/Action.cpp b/transition_machine/src/Action.cpp index 783756ae797e2068fe7aaef01654f08a294505eb..2f47ca441f11f4a5c77d32de21971a9987966d9a 100644 --- a/transition_machine/src/Action.cpp +++ b/transition_machine/src/Action.cpp @@ -62,6 +62,8 @@ std::string Action::BasicAction::to_string() void Action::printForDebug(FILE * output) { + fprintf(output, "%s :\n\t", name.c_str()); + for(auto & basic : sequence) fprintf(output, "%s ", basic.to_string().c_str()); fprintf(output, "\n"); diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index d8d9d4b5a1730d3aabd433b370ce5d64e6ba4ada..50109089cbf91b592dd6e910cee64e15f28e1879 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -256,35 +256,23 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na sequence.emplace_back(basicAction3); } - else if(std::string(b1) == "ROOT") + else if(std::string(b1) == "EOS") { auto apply0 = [b2](Config & c, Action::BasicAction &) { int b0 = c.head; - int rootIndex = c.stack.back(); - int eosIndex = rootIndex+1; - auto & tape = c.getTape("LABEL"); - while (eosIndex < (int)tape.hyp.size() && !tape.hyp[eosIndex].empty()) - eosIndex++; - - eosIndex--; - - simpleBufferWrite(c, "EOS", "1", eosIndex-b0); + int s0 = c.stack.back(); + simpleBufferWrite(c, "EOS", "1", s0-b0); }; auto undo0 = [](Config & c, Action::BasicAction) { int b0 = c.head; - int rootIndex = c.stack.back(); - int eosIndex = rootIndex; - auto & tape = c.getTape("EOS"); - while (eosIndex < (int)tape.hyp.size() && tape.hyp[eosIndex] != "1") - eosIndex++; - - simpleBufferWrite(c, "EOS", "", eosIndex-b0); + int s0 = c.stack.back(); + simpleBufferWrite(c, "EOS", "", s0-b0); }; auto appliable0 = [](Config & c, Action::BasicAction &) { - return !c.isFinal(); + return !c.isFinal() && !c.stack.empty(); }; Action::BasicAction basicAction0 = {Action::BasicAction::Type::Write, "", 
apply0, undo0, appliable0}; @@ -293,23 +281,32 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na auto apply = [](Config & c, Action::BasicAction &) { + auto & govs = c.getTape("GOV"); int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "GOV", "0", s0-b0); + int rootIndex = -1; + for (auto s : c.stack) + if (govs.hyp[s].empty()) + { + rootIndex = s; + break; + } + + simpleBufferWrite(c, "GOV", "0", rootIndex-b0); }; auto undo = [](Config & c, Action::BasicAction &) { + auto & govs = c.getTape("GOV"); int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "GOV", "", s0-b0); + for (auto s : c.stack) + if (govs.hyp[s] == "0") + { + simpleBufferWrite(c, "GOV", "", s-b0); + break; + } }; auto appliable = [](Config & c, Action::BasicAction &) { - if (c.stack.empty()) - return false; - int b0 = c.head; - int s0 = c.stack.back(); - return simpleBufferWriteAppliable(c, "GOV", s0-b0); + return !c.stack.empty(); }; Action::BasicAction basicAction = {Action::BasicAction::Type::Write, "", apply, undo, appliable}; @@ -318,93 +315,61 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na auto apply2 = [b2](Config & c, Action::BasicAction &) { + auto & govs = c.getTape("GOV"); int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "LABEL", "root", s0-b0); + + for (auto s : c.stack) + if (govs.hyp[s] == "0") + { + simpleBufferWrite(c, "LABEL", "root", s-b0); + break; + } }; auto undo2 = [](Config & c, Action::BasicAction &) { + auto & govs = c.getTape("GOV"); int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "LABEL", "", s0-b0); + + for (auto s : c.stack) + if (govs.hyp[s] == "0") + { + simpleBufferWrite(c, "LABEL", "", s-b0); + break; + } }; auto appliable2 = [](Config & c, Action::BasicAction &) { - if (c.stack.empty()) - return false; - int b0 = c.head; - int s0 = c.stack.back(); - return simpleBufferWriteAppliable(c, "LABEL", s0-b0); + return !c.stack.empty(); 
}; Action::BasicAction basicAction2 = {Action::BasicAction::Type::Write, "", apply2, undo2, appliable2}; sequence.emplace_back(basicAction2); - auto apply3 = [b2](Config & c, Action::BasicAction & ba) - { - ba.data = std::to_string(c.stack.back()); - c.stack.pop_back(); - }; - auto undo3 = [](Config & c, Action::BasicAction & ba) - { - c.stack.push_back(std::stoi(ba.data)); - }; - auto appliable3 = [](Config & c, Action::BasicAction &) - { - return !c.isFinal(); - }; - Action::BasicAction basicAction3 = - {Action::BasicAction::Type::Pop, "", apply3, undo3, appliable3}; - - sequence.emplace_back(basicAction3); - } - else if(std::string(b1) == "EOS") - { - auto apply = [](Config & c, Action::BasicAction &) - { - int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "EOS", "1", s0-b0); - }; - auto undo = [](Config & c, Action::BasicAction &) + auto apply4 = [b2](Config & c, Action::BasicAction & ba) { - int b0 = c.head; - int s0 = c.stack.back(); - simpleBufferWrite(c, "EOS", "", s0-b0); - }; - auto appliable = [](Config & c, Action::BasicAction &) - { - if (c.stack.empty()) - return false; - int b0 = c.head; - int s0 = c.stack.back(); - return simpleBufferWriteAppliable(c, "EOS", s0-b0); - }; - Action::BasicAction basicAction = - {Action::BasicAction::Type::Write, "", apply, undo, appliable}; - - sequence.emplace_back(basicAction); + ba.data = ""; + for (auto & s : c.stack) + ba.data += std::to_string(s) + " "; - auto apply2 = [](Config & c, Action::BasicAction & ba) - { - ba.data = std::to_string(c.stack.back()); - c.stack.pop_back(); + while (!c.stack.empty()) + c.stack.pop_back(); }; - auto undo2 = [](Config & c, Action::BasicAction & ba) + auto undo4 = [](Config & c, Action::BasicAction & ba) { - c.stack.push_back(std::stoi(ba.data)); + auto elems = split(ba.data); + for (int i = elems.size()-1; i >= 0; i--) + c.stack.push_back(std::stoi(elems[i])); }; - auto appliable2 = [](Config & c, Action::BasicAction &) + auto appliable4 = [](Config & c, 
Action::BasicAction &) { - return !c.isFinal(); + return !c.isFinal() && !c.stack.empty(); }; - Action::BasicAction basicAction2 = - {Action::BasicAction::Type::Pop, "", apply2, undo2, appliable2}; + Action::BasicAction basicAction4 = + {Action::BasicAction::Type::Write, "", apply4, undo4, appliable4}; - sequence.emplace_back(basicAction2); + sequence.emplace_back(basicAction4); } - else invalidNameAndAbort(ERRINFO); @@ -428,7 +393,7 @@ bool ActionBank::simpleBufferWriteAppliable(Config & config, int index = config.head + relativeIndex; - return (!(index < 0 || index >= (int)tape.hyp.size())); + return (!(index < 0 || index >= (int)tape.hyp.size())) && tape.hyp[index].empty(); } void ActionBank::writeRuleResult(Config & config, const std::string & fromTapeName, const std::string & targetTapeName, const std::string & rule, int relativeIndex) diff --git a/transition_machine/src/Classifier.cpp b/transition_machine/src/Classifier.cpp index 8680ff2ddbcab8643a60cf98d751d0c90a6a7d68..965b8058763536d30e69e0cdeb9c2fca364d9aca 100644 --- a/transition_machine/src/Classifier.cpp +++ b/transition_machine/src/Classifier.cpp @@ -155,26 +155,9 @@ int Classifier::getOracleActionIndex(Config & config) return as->getActionIndex(oracle->getAction(config)); } -int Classifier::trainOnBatch(TrainingExamples & examples) +int Classifier::getActionIndex(const std::string & action) { - if(type != Type::Prediction) - { - fprintf(stderr, "ERROR (%s) : classifier \'%s\' cannot be trained. Aborting.\n", ERRINFO, name.c_str()); - exit(1); - } - - return mlp->trainOnBatch(examples); -} - -int Classifier::getScoreOnBatch(TrainingExamples & examples) -{ - if(type != Type::Prediction) - { - fprintf(stderr, "ERROR (%s) : classifier \'%s\' cannot be trained. 
Aborting.\n", ERRINFO, name.c_str()); - exit(1); - } - - return mlp->getScoreOnBatch(examples); + return as->getActionIndex(action); } std::string Classifier::getActionName(int actionIndex) @@ -239,3 +222,20 @@ void Classifier::printTopology(FILE * output) mlp->printTopology(output); } +std::vector<std::string> Classifier::getZeroCostActions(Config & config) +{ + std::vector<std::string> result; + + for (Action & a : as->actions) + if (a.appliable(config) && oracle->actionIsZeroCost(config, a.name)) + result.emplace_back(a.name); + + return result; +} + +void Classifier::trainOnExample(Config & config, int gold) +{ + auto fd = fm->getFeatureDescription(config); + mlp->update(fd, gold); +} + diff --git a/transition_machine/src/Config.cpp b/transition_machine/src/Config.cpp index e126d27c886d1f7183ccd6229d2212b376c651ba..dee955014e8d1592f8a2685ac8c642033bfb1a9e 100644 --- a/transition_machine/src/Config.cpp +++ b/transition_machine/src/Config.cpp @@ -1,5 +1,6 @@ #include "Config.hpp" #include "File.hpp" +#include <algorithm> Config::Config(BD & bd, const std::string & expPath) : bd(bd), tapes(bd.getNbLines()) { @@ -149,7 +150,7 @@ void Config::moveHead(int mvt) bool Config::isFinal() { - return head >= (int)getTapeByInputCol(0).hyp.size()-1 && stack.empty(); + return head >= (int)getTapeByInputCol(0).hyp.size() -1 && stack.empty(); } void Config::reset() @@ -196,3 +197,35 @@ std::vector<std::string> & Config::getCurrentStateHistory() return actionHistory[getCurrentStateName()]; } +void Config::shuffle(const std::string & delimiterTape, const std::string & delimiter) +{ + auto & tape = getTape(delimiterTape); + std::vector< std::pair<unsigned int, unsigned int> > delimiters; + + unsigned int previousIndex = 0; + for (unsigned int i = 0; i < tape.ref.size(); i++) + if (tape.ref[i] == delimiter) + { + delimiters.emplace_back(previousIndex, i); + previousIndex = i+1; + } + + std::random_shuffle(delimiters.begin(), delimiters.end()); + + std::vector<Tape> newTapes = 
tapes; + + for (unsigned int tape = 0; tape < tapes.size(); tape++) + { + newTapes[tape].ref.clear(); + newTapes[tape].hyp.clear(); + + for (auto & delimiter : delimiters) + { + std::copy(tapes[tape].ref.begin()+delimiter.first, tapes[tape].ref.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].ref)); + std::copy(tapes[tape].hyp.begin()+delimiter.first, tapes[tape].hyp.begin()+delimiter.second+1, std::back_inserter(newTapes[tape].hyp)); + } + } + + tapes = newTapes; +} + diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index e1d3322fe05a121ad98139f2707870df24dd85a5..3566ae6d54d0b209c91d49c2c6427636808f3e1a 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -5,8 +5,10 @@ std::map< std::string, std::unique_ptr<Oracle> > Oracle::str2oracle; Oracle::Oracle(std::function<void(Oracle *)> initialize, - std::function<std::string(Config &, Oracle *)> findAction) + std::function<std::string(Config &, Oracle *)> findAction, + std::function<bool(Config &, Oracle *, const std::string &)> isZeroCost) { + this->isZeroCost = isZeroCost; this->findAction = findAction; this->initialize = initialize; this->isInit = false; @@ -37,6 +39,14 @@ Oracle * Oracle::getOracle(const std::string & name) return getOracle(name, ""); } +bool Oracle::actionIsZeroCost(Config & config, const std::string & action) +{ + if(!isInit) + init(); + + return isZeroCost(config, this, action); +} + std::string Oracle::getAction(Config & config) { if(!isInit) @@ -67,18 +77,32 @@ void Oracle::createDatabase() [](Oracle *) { }, - [](Config & c, Oracle *) + [](Config &, Oracle *) { - return "WRITE 0 POS " + c.getTape("POS").ref[c.head]; + fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. 
Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) + { + return action == "WRITE 0 POS " + c.getTape("POS").ref[c.head]; }))); str2oracle.emplace("morpho", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { }, - [](Config & c, Oracle *) + [](Config &, Oracle *) { - return "WRITE 0 MORPHO " + c.getTape("MORPHO").ref[c.head]; + fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) + { + return action == "WRITE 0 MORPHO " + c.getTape("MORPHO").ref[c.head]; }))); str2oracle.emplace("signature", std::unique_ptr<Oracle>(new Oracle( @@ -122,6 +146,10 @@ void Oracle::createDatabase() } return action; + }, + [](Config &, Oracle *, const std::string &) + { + return true; }))); str2oracle.emplace("lemma_lookup", std::unique_ptr<Oracle>(new Oracle( @@ -160,74 +188,129 @@ void Oracle::createDatabase() return std::string("NOTFOUND"); else return std::string("WRITE 0 LEMMA ") + lemma; + }, + [](Config &, Oracle *, const std::string &) + { + return true; }))); str2oracle.emplace("lemma_rules", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { }, - [](Config & c, Oracle *) + [](Config &, Oracle *) + { + fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. 
Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) { const std::string & form = c.getTape("FORM").ref[c.head]; const std::string & lemma = c.getTape("LEMMA").ref[c.head]; std::string rule = getRule(form, lemma); - return std::string("RULE LEMMA ON FORM ") + rule; + return action == std::string("RULE LEMMA ON FORM ") + rule; }))); str2oracle.emplace("parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle *) { }, - [](Config & c, Oracle *) + [](Config &, Oracle *) + { + fprintf(stderr, "ERROR (%s) : getAction called on Oracle of trainable Classifier. Aborting.\n", ERRINFO); + exit(1); + + return std::string(""); + }, + [](Config & c, Oracle *, const std::string & action) { auto & labels = c.getTape("LABEL"); auto & govs = c.getTape("GOV"); auto & eos = c.getTape("EOS"); - auto allDepsBeenPredicted = [&c, &labels, &govs, &eos](int index) + int head = c.head; + int stackHead = c.stack.empty() ? 0 : c.stack.back(); + int stackGov = stackHead + std::stoi(govs.ref[stackHead]); + int headGov = head + std::stoi(govs.ref[head]); + int sentenceStart = c.head; + int sentenceEnd = c.head; + + while(sentenceStart >= 0 && eos.ref[sentenceStart] != "1") + sentenceStart--; + if (sentenceStart != 0) + sentenceStart++; + while(sentenceEnd < (int)eos.ref.size() && eos.ref[sentenceEnd] != "1") + sentenceEnd++; + + auto parts = split(action); + + if (parts[0] == "SHIFT") { - for(int word = index-1; (word >= 0) && eos.ref[word] == "0"; word--) + for (int i = sentenceStart; i <= sentenceEnd; i++) { - if (std::stoi(govs.ref[word])+word == index) - if (govs.ref[word] != govs.hyp[word]) - return false; + int otherGov = i + std::stoi(govs.ref[i]); + + for (auto s : c.stack) + if (s == i) + if (otherGov == head || headGov == s) + return false; } - bool sentenceChange = false; - for(int word = index+1; (word < (int)labels.ref.size()) && !sentenceChange; word++) + return eos.ref[stackHead] != "1" && (eos.ref[head] != "1" || 
c.stack.empty()); + } + else if (parts[0] == "REDUCE") + { + for (int i = head; i <= sentenceEnd; i++) { - if(eos.ref[word] == "1") - sentenceChange = true; - - if (std::stoi(govs.ref[word])+word == index) - if (govs.ref[word] != govs.hyp[word]) - return false; + int otherGov = i + std::stoi(govs.ref[i]); + if (otherGov == stackHead) + return false; } - return true; - }; - - if (c.stack.empty()) - return std::string("SHIFT"); + return eos.ref[stackHead] != "1"; + } + else if (parts[0] == "LEFT") + { + if (stackGov == head && labels.ref[stackHead] == parts[1]) + return true; - int s0 = c.stack.back(); - int b0 = c.head; + for (int i = head; i <= sentenceEnd; i++) + { + int otherGov = i + std::stoi(govs.ref[i]); + if (otherGov == stackHead || stackGov == i) + return false; + } - if (labels.ref[s0] == "root" && allDepsBeenPredicted(s0)) - return std::string("ROOT"); + return labels.ref[stackHead] == parts[1]; + } + else if (parts[0] == "RIGHT") + { + for (auto s : c.stack) + { + if (s == c.stack.back()) + continue; - if (std::stoi(govs.ref[s0])+s0 == b0) - return std::string("LEFT ") + labels.ref[s0]; + int otherGov = s + std::stoi(govs.ref[s]); + if (otherGov == head || headGov == s) + return false; + } - if (std::stoi(govs.ref[b0])+b0 == s0) - return std::string("RIGHT ") + labels.ref[b0]; + for (int i = head; i <= sentenceEnd; i++) + if (headGov == i) + return false; - if (c.stack.size() > 1) - if (govs.hyp[s0] == govs.ref[s0] && allDepsBeenPredicted(s0)) - return std::string("REDUCE"); + return labels.ref[head] == parts[1]; + } + else if (parts[0] == "EOS") + { + return eos.ref[stackHead] == "1"; + } - return std::string("SHIFT"); + return false; }))); + }