From 4aea9d9dea653b185efecc1b56a13ee6309a96ce Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@etu.univ-amu.fr> Date: Tue, 4 Sep 2018 15:37:24 +0200 Subject: [PATCH] Changed the way .tm files are written --- decoder/include/Decoder.hpp | 1 + trainer/include/Trainer.hpp | 9 +++++++-- trainer/src/Trainer.cpp | 19 ++++++++++++++++--- trainer/src/macaon_train.cpp | 6 ++++-- transition_machine/src/TransitionMachine.cpp | 15 ++++++++++++--- 5 files changed, 40 insertions(+), 10 deletions(-) diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 1266f43..dc8bef5 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -33,6 +33,7 @@ class Decoder /// @param tm The trained TransitionMachine /// @param bd The BD we need to fill /// @param config The current configuration of the TransitionMachine + /// @param debugMode If true, infos will be printed on stderr. Decoder(TransitionMachine & tm, BD & bd, Config & config, bool debugMode); /// @brief Fill bd using tm. void decode(); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 3a601e3..7225082 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -37,6 +37,9 @@ class Trainer /// Can be nullptr if dev is not used in this training. Config * devConfig; + /// @brief If true, will print infos on stderr + bool debugMode; + public : /// @brief The FeatureDescritpion of a Config. @@ -105,7 +108,8 @@ void processAllExamples( /// @param tm The TransitionMachine to use. /// @param bd The BD to use. /// @param config The config to use. - Trainer(TransitionMachine & tm, BD & bd, Config & config); + /// @param debugMode If true, infos will be printed on stderr. + Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode); /// @brief Construct a new Trainer with a dev set. /// /// @param tm The TransitionMachine to use. @@ -113,7 +117,8 @@ void processAllExamples( /// @param config The Config corresponding to bd. 
/// @param devBD The BD corresponding to the dev dataset. /// @param devConfig The Config corresponding to devBD. - Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig); + /// @param debugMode If true, infos will be printed on stderr. + Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode); /// @brief Train the TransitionMachine. /// /// @param nbIter The number of training epochs. diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ceb0fa4..12c651b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,16 +1,17 @@ #include "Trainer.hpp" #include "util.hpp" -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config) { this->devBD = nullptr; this->devConfig = nullptr; + this->debugMode = debugMode; } -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) { - + this->debugMode = debugMode; } std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config) @@ -25,9 +26,21 @@ std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & Dict::currentClassifierName = classifier->name; classifier->initClassifier(config); + if (debugMode) + { + config.printForDebug(stderr); + fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); + } + int neededActionIndex = classifier->getOracleActionIndex(config); std::string neededActionName = classifier->getActionName(neededActionIndex); + if (debugMode) + { + fprintf(stderr, "Action : %s\n", 
neededActionName.c_str()); + fprintf(stderr, "\n"); + } + if(classifier->needsTrain()) examples[classifier].add(classifier->getFeatureDescription(config), neededActionIndex); diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index f362db5..ff47561 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -36,6 +36,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() ("help,h", "Produce this help message") + ("debug,d", "Print infos on stderr") ("dev", po::value<std::string>()->default_value(""), "Development corpus formated according to the MCD") ("lang", po::value<std::string>()->default_value("fr"), @@ -113,6 +114,7 @@ int main(int argc, char * argv[]) int batchSize = vm["batchsize"].as<int>(); int randomSeed = vm["seed"].as<int>(); bool mustShuffle = vm["shuffle"].as<bool>(); + bool debugMode = vm.count("debug") == 0 ? false : true; const char * MACAON_DIR = std::getenv("MACAON_DIR"); std::string slash = "/"; @@ -140,14 +142,14 @@ int main(int argc, char * argv[]) if(devFilename.empty()) { - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig)); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode)); } else { devBD.reset(new BD(BDfilename, MCDfilename)); devConfig.reset(new Config(*devBD.get(), expPath)); devConfig->readInput(devFilename); - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get())); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode)); } trainer->expPath = expPath; diff --git a/transition_machine/src/TransitionMachine.cpp b/transition_machine/src/TransitionMachine.cpp index 29f1fd7..b351ff6 100644 --- a/transition_machine/src/TransitionMachine.cpp +++ b/transition_machine/src/TransitionMachine.cpp @@ -77,7 +77,7 @@ TransitionMachine::TransitionMachine(const std::string & filename, bool trainMod // Reading all transitions int mvt; - 
while(fscanf(fd, "%s %s %s %d\n", buffer, buffer2, buffer3, &mvt) == 4) + while(fscanf(fd, "%s %s %d %[^\n]\n", buffer, buffer2, &mvt, buffer3) == 4) { std::string src(buffer); std::string dest(buffer2); @@ -116,12 +116,21 @@ TransitionMachine::State * TransitionMachine::getCurrentState() TransitionMachine::Transition * TransitionMachine::getTransition(const std::string & action) { - for (auto & transition : currentState->transitions) + int longestPrefix = -1; + + for (unsigned int i = 0; i < currentState->transitions.size(); i++) { + auto & transition = currentState->transitions[i]; + unsigned int currentMaxLength = longestPrefix >= 0 ? currentState->transitions[longestPrefix].actionPrefix.size() : 0; + if(!strncmp(action.c_str(), transition.actionPrefix.c_str(), transition.actionPrefix.size())) - return &transition; + if (transition.actionPrefix.size() > currentMaxLength) + longestPrefix = i; } + if (longestPrefix != -1) + return &currentState->transitions[longestPrefix]; + fprintf(stderr, "ERROR (%s) : no corresponding transition for action \'%s\' and state \'%s\'. Aborting.\n", ERRINFO, action.c_str(), currentState->name.c_str()); exit(1); -- GitLab