diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 722508256cde75ca054f1853063302b728bcf43f..9a2f98528c65b4d788cf99d3a3a111f5378e6f54 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -40,6 +40,9 @@ class Trainer /// @brief If true, will print infos on stderr bool debugMode; + /// @brief If true, duplicate examples will be removed from the training set. + bool removeDuplicates; + public : /// @brief The FeatureDescritpion of a Config. @@ -109,7 +112,8 @@ void processAllExamples( /// @param bd The BD to use. /// @param config The config to use. /// @param debugMode If true, infos will be printed on stderr. - Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode); + /// @param removeDuplicates If true, duplicate examples will be removed from the training set. + Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode, bool removeDuplicates); /// @brief Construct a new Trainer with a dev set. /// /// @param tm The TransitionMachine to use. @@ -118,7 +122,8 @@ void processAllExamples( /// @param devBD The BD corresponding to the dev dataset. /// @param devConfig The Config corresponding to devBD. /// @param debugMode If true, infos will be printed on stderr. - Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode); + /// @param removeDuplicates If true, duplicate examples will be removed from the training set. + Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates); /// @brief Train the TransitionMachine. /// /// @param nbIter The number of training epochs. 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index ec2ce9f999d2e5b13bba57ce3e747e71548534d6..69f92f8dfa2c026a057b972b15d720395710afe1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,17 +1,19 @@ #include "Trainer.hpp" #include "util.hpp" -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode, bool removeDuplicates) : tm(tm), trainBD(bd), trainConfig(config) { this->devBD = nullptr; this->devConfig = nullptr; this->debugMode = debugMode; + this->removeDuplicates = removeDuplicates; } -Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) +Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) { this->debugMode = debugMode; + this->removeDuplicates = removeDuplicates; } std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config) @@ -54,6 +56,10 @@ std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config.moveHead(transition->headMvt); } + if (removeDuplicates) + for (auto & it : examples) + it.second.removeDuplicates(); + return examples; } diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index ff47561aa7ea6376280254ac933482c01f03767f..2968d876ca2a2a485bb7a76ed327127a3ce20ee0 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -47,6 +47,8 @@ po::options_description getOptionsDescription() "Size of each training batch (in number of examples)") ("seed,s", po::value<int>()->default_value(100), "The random seed that will initialize RNG") + ("duplicates", po::value<bool>()->default_value(true), + "Remove identical 
training examples") ("shuffle", po::value<bool>()->default_value(true), "Shuffle examples after each iteration"); @@ -114,6 +116,7 @@ int main(int argc, char * argv[]) int batchSize = vm["batchsize"].as<int>(); int randomSeed = vm["seed"].as<int>(); bool mustShuffle = vm["shuffle"].as<bool>(); + bool removeDuplicates = vm["duplicates"].as<bool>(); bool debugMode = vm.count("debug") == 0 ? false : true; const char * MACAON_DIR = std::getenv("MACAON_DIR"); @@ -142,14 +145,14 @@ int main(int argc, char * argv[]) if(devFilename.empty()) { - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode)); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode, removeDuplicates)); } else { devBD.reset(new BD(BDfilename, MCDfilename)); devConfig.reset(new Config(*devBD.get(), expPath)); devConfig->readInput(devFilename); - trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode)); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode, removeDuplicates)); } trainer->expPath = expPath; diff --git a/transition_machine/include/FeatureModel.hpp b/transition_machine/include/FeatureModel.hpp index 13bd9aa9cc3056a16f73b7cd36076936ae413a1e..a7d21fc1698a7433815de06caea32b87effe5400 100644 --- a/transition_machine/include/FeatureModel.hpp +++ b/transition_machine/include/FeatureModel.hpp @@ -51,6 +51,11 @@ class FeatureModel /// /// @param output Where to print. 
void printForDebug(FILE * output); + + /// @brief Return a string representing this FeatureDescription + /// + /// @return The string representing this FeatureDescription + std::string toString(); }; private : diff --git a/transition_machine/include/TrainingExamples.hpp b/transition_machine/include/TrainingExamples.hpp index eeb30314d2bdd3d1984ea036136899056e7a0d26..4fd52a177419eca6c80a0a951748ea305e8dfa5e 100644 --- a/transition_machine/include/TrainingExamples.hpp +++ b/transition_machine/include/TrainingExamples.hpp @@ -26,6 +26,8 @@ class TrainingExamples TrainingExamples getBatch(unsigned int batchSize); void reset(); void shuffle(); + void remove(int index); + void removeDuplicates(); }; #endif diff --git a/transition_machine/src/FeatureModel.cpp b/transition_machine/src/FeatureModel.cpp index 6f05e056962264b9696abdb906c780462ff9e172..ce272690c21b72abd1772a880438cbe4a3f3ac33 100644 --- a/transition_machine/src/FeatureModel.cpp +++ b/transition_machine/src/FeatureModel.cpp @@ -37,23 +37,40 @@ FeatureModel::FeatureModel(const std::string & filename) void FeatureModel::FeatureDescription::printForDebug(FILE * output) { + fprintf(output, "%s", toString().c_str()); +} + +std::string FeatureModel::FeatureDescription::toString() +{ + char buffer[1024]; + std::string res; + int nbCol = 80; for(int i = 0; i < nbCol; i++) - fprintf(output, "-"); - fprintf(output, "\n"); + res.push_back('-'); + res.push_back('\n'); for(auto featValue : values) { - fprintf(output, "Feature=%s, Policy=%s, Value=%s\n", featValue.name.c_str(), policy2str(featValue.policy), featValue.value->c_str()); + res += "Feature=" + featValue.name; + res += ", Policy=" + std::string(policy2str(featValue.policy)); + res += ", Value=" + std::string(*featValue.value); + res.push_back('\n'); + for(float val : *featValue.vec) - fprintf(output, "%.2f ", val); - fprintf(output, "\n"); + { + sprintf(buffer, "%.2f ", val); + res += std::string(buffer); + } + res.push_back('\n'); } for(int i = 0; i < nbCol; i++) 
- fprintf(output, "-"); - fprintf(output, "\n"); + res.push_back('-'); + res.push_back('\n'); + + return res; } const char * FeatureModel::policy2str(Policy policy) diff --git a/transition_machine/src/TrainingExamples.cpp b/transition_machine/src/TrainingExamples.cpp index 12ac7f104e196bd9d649c51c43545af97177c1bd..a887acafb0e1222e975fdeb0a4717003790cb621 100644 --- a/transition_machine/src/TrainingExamples.cpp +++ b/transition_machine/src/TrainingExamples.cpp @@ -15,7 +15,7 @@ void TrainingExamples::add(const FeatureModel::FeatureDescription & example, int unsigned int TrainingExamples::size() { - return examples.size(); + return order.size(); } TrainingExamples TrainingExamples::getBatch(unsigned int batchSize) @@ -41,3 +41,32 @@ void TrainingExamples::shuffle() std::random_shuffle(order.begin(), order.end()); } +void TrainingExamples::removeDuplicates() +{ + std::map<std::string, int> lastIndex; + std::map<int, bool> toRemove; + + for (unsigned int i = 0; i < examples.size(); i++) + { + std::string example = examples[i].toString(); + + if (lastIndex.count(example)) + toRemove[i] = true; + else + lastIndex[example] = i; + } + + for (auto & it : toRemove) + remove(it.first); +} + +void TrainingExamples::remove(int index) +{ + for (unsigned int i = 0; i < order.size(); i++) + if ((int)order[i] == index) + { + order[i] = order.back(); + order.pop_back(); + } +} +