Skip to content
Snippets Groups Projects
Commit c17a1abf authored by Franck Dary's avatar Franck Dary
Browse files

Improved training speed (x2)

parent 84d541e4
Branches
No related tags found
No related merge requests found
......@@ -58,13 +58,13 @@ void checkAndRecordError(Config & config, Classifier * classifier, Classifier::W
}
}
void printAdvancement(Config & config, float currentSpeed)
void printAdvancement(Config & config, float currentSpeed, int nbActionsCutoff)
{
if (ProgramParameters::interactive)
{
int totalSize = ProgramParameters::tapeSize;
int steps = config.getHead();
if (steps && (steps % 200 == 0 || totalSize-steps < 200))
if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
fprintf(stderr, "Decode : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
}
}
......@@ -213,7 +213,7 @@ void Decoder::decodeNoBeam()
auto weightedActions = tm.getCurrentClassifier()->weightActions(config);
printAdvancement(config, currentSpeed);
printAdvancement(config, currentSpeed, nbActionsCutoff);
printDebugInfos(stderr, config, tm, weightedActions);
std::pair<float,std::string> predictedAction;
......@@ -343,7 +343,7 @@ void Decoder::decodeBeam()
node->weightedActions = node->tm.getCurrentClassifier()->weightActions(node->config);
printAdvancement(node->config, currentSpeed);
printAdvancement(node->config, currentSpeed, nbActionsCutoff);
unsigned int nbActionsMax = std::min(std::max(node->tm.getCurrentClassifier()->getNbActions(),(unsigned int)1),(unsigned int)ProgramParameters::nbChilds);
for (unsigned int actionIndex = 0; actionIndex < nbActionsMax; actionIndex++)
......
......@@ -10,6 +10,14 @@
#include <string>
#include "FeatureModel.hpp"
struct BatchNotFull : public std::exception
{
const char * what() const throw()
{
return "Current batch is not full, no need to update.";
}
};
class NeuralNetwork
{
public :
......
......@@ -48,12 +48,16 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
}
float MLP::update(FeatureModel::FeatureDescription & fd, int gold)
{
try
{
float loss = mlp.update(fd, gold);
trainer->update();
return loss;
} catch (BatchNotFull &)
{
return 0.0;
}
}
void MLP::save(const std::string & filename)
......
......@@ -95,7 +95,7 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, int gold)
golds.emplace_back(gold);
if ((int)fds.size() < ProgramParameters::batchSize)
return 0.0;
throw BatchNotFull();
std::vector<dynet::Expression> inputs;
dynet::ComputationGraph cg;
......
......@@ -27,7 +27,7 @@ void Trainer::computeScoreOnDev()
float entropyAccumulator = 0.0;
bool justFlipped = false;
int nbActions = 0;
int nbActionsCutoff = 200;
int nbActionsCutoff = 2*ProgramParameters::batchSize;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
std::vector<float> entropies;
......@@ -56,7 +56,7 @@ void Trainer::computeScoreOnDev()
{
int totalSize = ProgramParameters::devTapeSize;
int steps = devConfig->getHead();
if (steps && (steps % 200 == 0 || totalSize-steps < 200))
if (steps && (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff))
{
fprintf(stderr, " \r");
fprintf(stderr, "Eval on dev : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
......@@ -162,7 +162,7 @@ void Trainer::train()
int nbSteps = 0;
int nbActions = 0;
int nbActionsCutoff = 200;
int nbActionsCutoff = 2*ProgramParameters::batchSize;
float currentSpeed = 0.0;
auto pastTime = std::chrono::high_resolution_clock::now();
while (TI.getEpoch() <= ProgramParameters::nbIter)
......@@ -204,7 +204,7 @@ void Trainer::train()
{
int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
if (steps % 200 == 0 || totalSize-steps < 200)
if (steps % nbActionsCutoff == 0 || totalSize-steps < nbActionsCutoff)
{
fprintf(stderr, " \r");
fprintf(stderr, "Current Iteration : %.2f%% speed : %s actions/s\r", 100.0*steps/totalSize, int2humanStr((int)currentSpeed).c_str());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment