Skip to content
Snippets Groups Projects
Commit 5c8183f2 authored by Franck Dary's avatar Franck Dary
Browse files

Added an option to remove duplicate training examples

parent 1e0807ce
No related branches found
No related tags found
No related merge requests found
......@@ -40,6 +40,9 @@ class Trainer
/// @brief If true, will print infos on stderr
bool debugMode;
/// @brief If true, duplicate examples will be removed from the training set.
bool removeDuplicates;
public :
/// @brief The FeatureDescription of a Config.
......@@ -109,7 +112,8 @@ void processAllExamples(
/// @param bd The BD to use.
/// @param config The config to use.
/// @param debugMode If true, infos will be printed on stderr.
Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode);
/// @param removeDuplicates If true, duplicate examples will be removed from the training set.
Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode, bool removeDuplicates);
/// @brief Construct a new Trainer with a dev set.
///
/// @param tm The TransitionMachine to use.
......@@ -118,7 +122,8 @@ void processAllExamples(
/// @param devBD The BD corresponding to the dev dataset.
/// @param devConfig The Config corresponding to devBD.
/// @param debugMode If true, infos will be printed on stderr.
Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode);
/// @param removeDuplicates If true, duplicate examples will be removed from the training set.
Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates);
/// @brief Train the TransitionMachine.
///
/// @param nbIter The number of training epochs.
......
#include "Trainer.hpp"
#include "util.hpp"
/// @brief Construct a Trainer that has no dev set.
///
/// Binds tm, trainBD and trainConfig to the given references, sets devBD and
/// devConfig to nullptr, and stores the debugMode and removeDuplicates flags.
// NOTE(review): the two signature lines below look like the pre-change and
// post-change lines of a rendered diff; only the second one declares the
// removeDuplicates parameter that the body reads — confirm against the repository.
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode)
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, bool debugMode, bool removeDuplicates)
: tm(tm), trainBD(bd), trainConfig(config)
{
// This overload receives no dev data, so both dev pointers are cleared.
this->devBD = nullptr;
this->devConfig = nullptr;
this->debugMode = debugMode;
// Stored so that getExamplesByClassifier can decide whether to deduplicate.
this->removeDuplicates = removeDuplicates;
}
/// @brief Construct a Trainer that also has a dev set.
///
/// Binds tm, trainBD and trainConfig to the given references, stores the devBD
/// and devConfig pointers as given, and records the debugMode and
/// removeDuplicates flags.
// NOTE(review): the two signature lines below look like the pre-change and
// post-change lines of a rendered diff; only the second one declares the
// removeDuplicates parameter that the body reads — confirm against the repository.
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig)
Trainer::Trainer(TransitionMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig, bool debugMode, bool removeDuplicates) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig)
{
this->debugMode = debugMode;
// Stored so that getExamplesByClassifier can decide whether to deduplicate.
this->removeDuplicates = removeDuplicates;
}
std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config & config)
......@@ -54,6 +56,10 @@ std::map<Classifier*,TrainingExamples> Trainer::getExamplesByClassifier(Config &
config.moveHead(transition->headMvt);
}
if (removeDuplicates)
for (auto & it : examples)
it.second.removeDuplicates();
return examples;
}
......
......@@ -47,6 +47,8 @@ po::options_description getOptionsDescription()
"Size of each training batch (in number of examples)")
("seed,s", po::value<int>()->default_value(100),
"The random seed that will initialize RNG")
("duplicates", po::value<bool>()->default_value(true),
"Remove identical training examples")
("shuffle", po::value<bool>()->default_value(true),
"Shuffle examples after each iteration");
......@@ -114,6 +116,7 @@ int main(int argc, char * argv[])
int batchSize = vm["batchsize"].as<int>();
int randomSeed = vm["seed"].as<int>();
bool mustShuffle = vm["shuffle"].as<bool>();
bool removeDuplicates = vm["duplicates"].as<bool>();
bool debugMode = vm.count("debug") == 0 ? false : true;
const char * MACAON_DIR = std::getenv("MACAON_DIR");
......@@ -142,14 +145,14 @@ int main(int argc, char * argv[])
if(devFilename.empty())
{
trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode));
trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, debugMode, removeDuplicates));
}
else
{
devBD.reset(new BD(BDfilename, MCDfilename));
devConfig.reset(new Config(*devBD.get(), expPath));
devConfig->readInput(devFilename);
trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode));
trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get(), debugMode, removeDuplicates));
}
trainer->expPath = expPath;
......
......@@ -51,6 +51,11 @@ class FeatureModel
///
/// @param output Where to print.
void printForDebug(FILE * output);
/// @brief Return a string representing this FeatureDescription
///
/// @return The string representing this FeatureDescription
std::string toString();
};
private :
......
......@@ -26,6 +26,8 @@ class TrainingExamples
TrainingExamples getBatch(unsigned int batchSize);
void reset();
void shuffle();
void remove(int index);
void removeDuplicates();
};
#endif
......@@ -37,23 +37,40 @@ FeatureModel::FeatureModel(const std::string & filename)
/// @brief Write the textual form of this FeatureDescription to a stream.
///
/// The text is exactly what toString() returns; nothing is added around it.
///
/// @param output Where to print.
void FeatureModel::FeatureDescription::printForDebug(FILE * output)
{
  const std::string text = toString();
  fputs(text.c_str(), output);
}
/// @brief Build the textual form of this FeatureDescription.
///
/// Layout of the res-based code: a row of nbCol '-' characters, then for every
/// entry of values one "Feature=..., Policy=..., Value=..." line followed by a
/// line holding the entry's vector with two decimals per component, then a
/// closing row of '-' characters.
///
/// @return The string representing this FeatureDescription.
// NOTE(review): the fprintf(output, ...) lines interleaved below use a name
// (output) that is not a parameter of this function; they are presumably the
// removed lines of a rendered diff, and the res-based lines the current code —
// confirm against the repository.
std::string FeatureModel::FeatureDescription::toString()
{
// Scratch space for formatting one float at a time.
char buffer[1024];
std::string res;
// Width of the separator rows.
int nbCol = 80;
for(int i = 0; i < nbCol; i++)
fprintf(output, "-");
fprintf(output, "\n");
res.push_back('-');
res.push_back('\n');
for(auto featValue : values)
{
fprintf(output, "Feature=%s, Policy=%s, Value=%s\n", featValue.name.c_str(), policy2str(featValue.policy), featValue.value->c_str());
res += "Feature=" + featValue.name;
res += ", Policy=" + std::string(policy2str(featValue.policy));
res += ", Value=" + std::string(*featValue.value);
res.push_back('\n');
for(float val : *featValue.vec)
fprintf(output, "%.2f ", val);
fprintf(output, "\n");
{
// "%.2f" rounds to two decimals, so vectors that differ only beyond that
// precision produce the same text (relevant to removeDuplicates, which
// compares these strings).
sprintf(buffer, "%.2f ", val);
res += std::string(buffer);
}
res.push_back('\n');
}
for(int i = 0; i < nbCol; i++)
fprintf(output, "-");
fprintf(output, "\n");
res.push_back('-');
res.push_back('\n');
return res;
}
const char * FeatureModel::policy2str(Policy policy)
......
......@@ -15,7 +15,7 @@ void TrainingExamples::add(const FeatureModel::FeatureDescription & example, int
/// @brief Number of examples in this set.
///
/// @return The size of the underlying container as an unsigned int.
// NOTE(review): the two return statements are presumably the pre-change
// (examples.size()) and post-change (order.size()) lines of a rendered diff.
// remove() shrinks order only, so order.size() would be the count that
// reflects removals — confirm which line the repository keeps.
unsigned int TrainingExamples::size()
{
return examples.size();
return order.size();
}
TrainingExamples TrainingExamples::getBatch(unsigned int batchSize)
......@@ -41,3 +41,32 @@ void TrainingExamples::shuffle()
std::random_shuffle(order.begin(), order.end());
}
void TrainingExamples::removeDuplicates()
{
std::map<std::string, int> lastIndex;
std::map<int, bool> toRemove;
for (unsigned int i = 0; i < examples.size(); i++)
{
std::string example = examples[i].toString();
if (lastIndex.count(example))
toRemove[i] = true;
else
lastIndex[example] = i;
}
for (auto & it : toRemove)
remove(it.first);
}
/// @brief Remove every occurrence of an example index from the ordering.
///
/// Only order is modified. A matching slot is overwritten with the last entry
/// and the container is shortened by one, so the relative order of the
/// remaining entries is not preserved.
///
/// @param index The example index to take out of order.
void TrainingExamples::remove(int index)
{
  unsigned int i = 0;

  while (i < order.size())
  {
    if ((int)order[i] != index)
    {
      i++;
      continue;
    }

    // Swap-with-last removal. Do not advance i: the entry just moved into
    // this slot has not been examined yet and may match as well.
    order[i] = order.back();
    order.pop_back();
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment