diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 1266f433bb1178405ac6b6600252d0e87c42f812..dc8bef588841a0b6cb99a1a30ebcce3f6de5b0fb 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -33,6 +33,7 @@ class Decoder /// @param tm The trained TransitionMachine /// @param bd The BD we need to fill /// @param config The current configuration of the TransitionMachine + /// @param debugMode If true, infos will be printed on stderr. Decoder(TransitionMachine & tm, BD & bd, Config & config, bool debugMode); /// @brief Fill bd using tm. void decode(); diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 3a601e3e0289158e11e7b7ea9596603e65de542d..722508256cde75ca054f1853063302b728bcf43f 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -37,6 +37,9 @@ class Trainer /// Can be nullptr if dev is not used in this training. Config * devConfig; + /// @brief If true, will print infos on stderr + bool debugMode; + public : /// @brief The FeatureDescritpion of a Config. @@ -105,7 +108,8 @@ void processAllExamples( /// @param tm The TransitionMachine to use. /// @param bd The BD to use. /// @param config The config to use. - Trainer(TransitionMachine & tm, BD & bd, Config & config); + /// @param debugMode If true, infos will be printed on stderr. + Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode); /// @brief Construct a new Trainer with a dev set. /// /// @param tm The TransitionMachine to use. @@ -113,7 +117,8 @@ void processAllExamples( /// @param config The Config corresponding to bd. /// @param devBD The BD corresponding to the dev dataset. /// @param devConfig The Config corresponding to devBD. - Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig); + /// @param debugMode If true, infos will be printed on stderr. + Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode); /// @brief Train the TransitionMachine. /// /// @param nbIter The number of training epochs. diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ceb0fa4c3bec9dd31b14b59b35007a62214fdf5a..12c651b07c9a48e7ed29cf5119516bf312a1815b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,16 +1,17 @@ #include "Trainer.hpp" #include "util.hpp" -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config) { this->devBD = nullptr; this->devConfig = nullptr; + this->debugMode = debugMode; } -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) { - + this->debugMode = debugMode; } std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config) @@ -25,9 +26,21 @@ std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & Dict::currentClassifierName = classifier->name; classifier->initClassifier(config); + if (debugMode) + { + config.printForDebug(stderr); + fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); + } + int neededActionIndex = classifier->getOracleActionIndex(config); std::string neededActionName = classifier->getActionName(neededActionIndex); + if (debugMode) + { + fprintf(stderr, "Action : %s\n", neededActionName.c_str()); + fprintf(stderr, "\n"); + } + if(classifier->needsTrain()) examples[classifier].add(classifier->getFeatureDescription(config), neededActionIndex); diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index f362db5ea8bde0e58eab5839e1cda59141c93696..ff47561aa7ea6376280254ac933482c01f03767f 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -36,6 +36,7 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() ("help,h", "Produce this help message") + ("debug,d", "Print infos on stderr") ("dev", po::value<std::string>()->default_value(""), "Development corpus formated according to the MCD") ("lang", po::value<std::string>()->default_value("fr"), @@ -113,6 +114,7 @@ int main(int argc, char * argv[]) int batchSize = vm["batchsize"].as<int>(); int randomSeed = vm["seed"].as<int>(); bool mustShuffle = vm["shuffle"].as<bool>(); + bool debugMode = vm.count("debug") == 0 ? false : true; const char * MACAON_DIR = std::getenv("MACAON_DIR"); std::string slash = "/"; @@ -140,14 +142,14 @@ int main(int argc, char * argv[]) if(devFilename.empty()) { - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig)); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode)); } else { devBD.reset(new BD(BDfilename, MCDfilename)); devConfig.reset(new Config(*devBD.get(), expPath)); devConfig->readInput(devFilename); - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get())); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode)); } trainer->expPath = expPath; diff --git a/transition_machine/src/TransitionMachine.cpp b/transition_machine/src/TransitionMachine.cpp index 29f1fd7a492673cfd0e3142d7b3dc0e348d87696..b351ff6b8999ed21acc2fc78f62046f20968c61b 100644 --- a/transition_machine/src/TransitionMachine.cpp +++ b/transition_machine/src/TransitionMachine.cpp @@ -77,7 +77,7 @@ TransitionMachine::TransitionMachine(const std::string & filename, bool trainMod // Reading all transitions int mvt; - while(fscanf(fd, "%s %s %s %d\n", buffer, buffer2, buffer3, &mvt) == 4) + while(fscanf(fd, "%s %s %d %[^\n]\n", buffer, buffer2, &mvt, buffer3) == 4) { std::string src(buffer); std::string dest(buffer2); @@ -116,12 +116,21 @@ TransitionMachine::State * TransitionMachine::getCurrentState() TransitionMachine::Transition * TransitionMachine::getTransition(const std::string & action) { - for (auto & transition : currentState->transitions) + int longestPrefix = -1; + + for (unsigned int i = 0; i < currentState->transitions.size(); i++) { + auto & transition = currentState->transitions[i]; + unsigned int currentMaxLength = longestPrefix >= 0 ? currentState->transitions[longestPrefix].actionPrefix.size() : 0; + if(!strncmp(action.c_str(), transition.actionPrefix.c_str(), transition.actionPrefix.size())) - return &transition; + if (transition.actionPrefix.size() > currentMaxLength) + longestPrefix = i; } + if (longestPrefix != -1) + return ¤tState->transitions[longestPrefix]; + fprintf(stderr, "ERROR (%s) : no corresponding transition for action \'%s\' and state \'%s\'. Aborting.\n", ERRINFO, action.c_str(), currentState->name.c_str()); exit(1);