Skip to content
Snippets Groups Projects
Commit 20b709cf authored by Franck Dary's avatar Franck Dary
Browse files

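Made train_error_detector use less memory: train and dev configurations are no longer all kept in std::vector<Config> containers; instead a single Config is reloaded from the rewound corpus file for each example.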

parent 173b1131
No related branches found
No related tags found
No related merge requests found
...@@ -178,8 +178,10 @@ macaon_decode --lang " + ProgramParameters::lang + " --tm machine.tm --bd test. ...@@ -178,8 +178,10 @@ macaon_decode --lang " + ProgramParameters::lang + " --tm machine.tm --bd test.
if (system(("ln -f -s " + ProgramParameters::expPath + "decode.sh " + ProgramParameters::langPath + "bin/maca_tm_" + ProgramParameters::expName).c_str())){} if (system(("ln -f -s " + ProgramParameters::expPath + "decode.sh " + ProgramParameters::langPath + "bin/maca_tm_" + ProgramParameters::expName).c_str())){}
} }
std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev(TransitionMachine & tm, std::vector<Config> devConfigs, std::vector<int> & devIsErrors, std::vector<int> &) std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev(TransitionMachine & tm, std::vector<int> & devIsErrors, std::vector<int> &, File & dev, Config & devConfig)
{ {
dev.rewind();
FILE * devPtr = dev.getDescriptor();
tm.reset(); tm.reset();
std::map< std::string, std::pair<int, int> > counts; std::map< std::string, std::pair<int, int> > counts;
...@@ -190,9 +192,18 @@ std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev ...@@ -190,9 +192,18 @@ std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev
std::vector<int> predictions; std::vector<int> predictions;
std::string classifierName; std::string classifierName;
for (unsigned int i = 0; i < devConfigs.size(); i++) int isError, errorIndex;
for (unsigned int i = 0; i < devIsErrors.size(); i++)
{
if (fscanf(devPtr, "%d\t%d\n", &isError, &errorIndex) != 2)
{ {
auto & devConfig = devConfigs[i]; fprintf(stderr, "ERROR (%s) : corpus bad format. Aborting.\n", ERRINFO);
exit(1);
}
devConfig.loadFromFile(dev);
TransitionMachine::State * currentState = tm.getCurrentState(); TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier; Classifier * classifier = currentState->classifier;
devConfig.setCurrentStateName(&currentState->name); devConfig.setCurrentStateName(&currentState->name);
...@@ -271,7 +282,7 @@ std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev ...@@ -271,7 +282,7 @@ std::map<std::string, std::pair<float, std::pair<float, float> > > getScoreOnDev
return scores; return scores;
} }
void printScoresAndSave(FILE * output, std::map< std::string, std::pair<int, int> > & trainCounter, std::map< std::string, float > & scores, TransitionMachine & tm, int curIter, std::map< std::string, float > & bestScores, std::vector<Config> & devConfigs, std::vector<int> & devIsErrors, std::vector<int> & devErrorIndexes) void printScoresAndSave(FILE * output, std::map< std::string, std::pair<int, int> > & trainCounter, std::map< std::string, float > & scores, TransitionMachine & tm, int curIter, std::map< std::string, float > & bestScores, std::vector<int> & devIsErrors, std::vector<int> & devErrorIndexes, File & devFile, Config & config)
{ {
for (auto & it : trainCounter) for (auto & it : trainCounter)
scores[it.first] = 100.0 * it.second.second / it.second.first; scores[it.first] = 100.0 * it.second.second / it.second.first;
...@@ -284,7 +295,7 @@ void printScoresAndSave(FILE * output, std::map< std::string, std::pair<int, int ...@@ -284,7 +295,7 @@ void printScoresAndSave(FILE * output, std::map< std::string, std::pair<int, int
std::map<std::string, bool> saved; std::map<std::string, bool> saved;
auto devScores = getScoreOnDev(tm, devConfigs, devIsErrors, devErrorIndexes); auto devScores = getScoreOnDev(tm, devIsErrors, devErrorIndexes, devFile, config);
for (auto & it : devScores) for (auto & it : devScores)
{ {
...@@ -354,56 +365,55 @@ void launchTraining() ...@@ -354,56 +365,55 @@ void launchTraining()
std::map< std::string, bool > topologyPrinted; std::map< std::string, bool > topologyPrinted;
std::map< std::string, std::pair<int, int> > trainCounter; std::map< std::string, std::pair<int, int> > trainCounter;
int curIter = 0; int curIter = 0;
std::vector<Config> configs;
std::vector<int> isErrors; std::vector<int> isErrors;
std::vector<int> errorIndexes; std::vector<int> errorIndexes;
std::vector<Config> devConfigs;
std::vector<int> devIsErrors; std::vector<int> devIsErrors;
std::vector<int> devErrorIndexes; std::vector<int> devErrorIndexes;
int isError; int isError;
int errorIndex; int errorIndex;
Config config(trainBD);
fprintf(stderr, "Reading train corpus..."); fprintf(stderr, "Reading train corpus...");
while (fscanf(trainPtr, "%d\t%d\n", &isError, &errorIndex) == 2) while (fscanf(trainPtr, "%d\t%d\n", &isError, &errorIndex) == 2)
{ {
configs.emplace_back(trainBD);
isErrors.emplace_back(isError); isErrors.emplace_back(isError);
errorIndexes.emplace_back(errorIndex); errorIndexes.emplace_back(errorIndex);
configs.back().loadFromFile(train); config.loadFromFile(train);
} }
fprintf(stderr, " done !\n"); fprintf(stderr, " done !\n");
fprintf(stderr, "Reading dev corpus..."); fprintf(stderr, "Reading dev corpus...");
while (fscanf(devPtr, "%d\t%d\n", &isError, &errorIndex) == 2) while (fscanf(devPtr, "%d\t%d\n", &isError, &errorIndex) == 2)
{ {
devConfigs.emplace_back(trainBD);
devIsErrors.emplace_back(isError); devIsErrors.emplace_back(isError);
devErrorIndexes.emplace_back(errorIndex); devErrorIndexes.emplace_back(errorIndex);
devConfigs.back().loadFromFile(dev); config.loadFromFile(dev);
} }
fprintf(stderr, " done !\n"); fprintf(stderr, " done !\n");
auto resetAndShuffle = [&configs,&trainCounter]() auto resetAndShuffle = [&trainCounter,&train,&dev,&trainPtr]()
{ {
//TODO shuffle train.rewind();
/* dev.rewind();
if(ProgramParameters::shuffleExamples) trainPtr = train.getDescriptor();
std::random_shuffle(configs.begin(), configs.end());
*/
for (auto & it : trainCounter) for (auto & it : trainCounter)
it.second.first = it.second.second = 0; it.second.first = it.second.second = 0;
}; };
Config trainConfig(trainBD);
while (curIter < ProgramParameters::nbIter) while (curIter < ProgramParameters::nbIter)
{ {
resetAndShuffle(); resetAndShuffle();
for (unsigned int i = 0; i < configs.size(); i++) for (unsigned int i = 0; i < isErrors.size(); i++)
{
if (fscanf(trainPtr, "%d\t%d\n", &isError, &errorIndex) != 2)
{ {
auto & trainConfig = configs[i]; fprintf(stderr, "ERROR (%s) : corpus bad format. Aborting.\n", ERRINFO);
isError = isErrors[i]; exit(1);
errorIndex = errorIndexes[i]; }
trainConfig.loadFromFile(train);
TransitionMachine::State * currentState = tm.getCurrentState(); TransitionMachine::State * currentState = tm.getCurrentState();
Classifier * classifier = currentState->classifier; Classifier * classifier = currentState->classifier;
...@@ -420,7 +430,7 @@ void launchTraining() ...@@ -420,7 +430,7 @@ void launchTraining()
// Print current iter advancement in percentage // Print current iter advancement in percentage
if (ProgramParameters::interactive) if (ProgramParameters::interactive)
{ {
int totalSize = configs.size(); int totalSize = isErrors.size();
int steps = i; int steps = i;
if (steps % 200 == 0 || totalSize-steps < 200) if (steps % 200 == 0 || totalSize-steps < 200)
fprintf(stderr, "Current Iteration : %.2f%%\r", 100.0*steps/totalSize); fprintf(stderr, "Current Iteration : %.2f%%\r", 100.0*steps/totalSize);
...@@ -445,7 +455,7 @@ void launchTraining() ...@@ -445,7 +455,7 @@ void launchTraining()
trainCounter[classifier->name].second += pAction == oAction ? 1 : 0; trainCounter[classifier->name].second += pAction == oAction ? 1 : 0;
} }
printScoresAndSave(stderr, trainCounter, scores, tm, curIter, bestScores, devConfigs, devIsErrors, devErrorIndexes); printScoresAndSave(stderr, trainCounter, scores, tm, curIter, bestScores, devIsErrors, devErrorIndexes, dev, config);
curIter++; curIter++;
} }
} }
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment