Commit 760f9ec8 authored by Franck Dary

Training is now working with batches

parent 9811325d
@@ -57,13 +57,14 @@ class MLP
dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
inline dynet::Expression activate(dynet::Expression h, Activation f);
- dynet::Expression getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label);
void printParameters(FILE * output);
public :
MLP(std::vector<Layer> layers);
std::vector<float> predict(FeatureModel::FeatureDescription & fd, int goldClass);
+ int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
};
#endif
@@ -2,6 +2,7 @@
#include "util.hpp"
#include <dynet/param-init.h>
#include <dynet/io.h>
std::string MLP::activation2str(Activation a)
{
@@ -150,7 +151,7 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & fv)
return it->second;
//ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
- ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1});
+ ptr2parameter[fv.vec] = model.add_parameters({(unsigned)fv.vec->size(),1});
it = ptr2parameter.find(fv.vec);
// it->second.values()->v = fv.vec->data();
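Side note on the hunk above: dynet::Dim is built from unsigned dimensions while std::vector::size() returns a size_t, hence the added cast. The surrounding function implements a pointer-keyed lazy cache: each distinct feature-value vector gets exactly one dynet::Parameter, created on first lookup. A minimal standalone sketch of that pattern, assuming dynet 2.x (the ParameterCollection, map, and function names below are illustrative stand-ins, not the repository's code):

#include <dynet/dynet.h>
#include <map>
#include <vector>

dynet::ParameterCollection model;                       // stand-in for MLP's model
std::map<std::vector<float>*, dynet::Parameter> cache;  // stand-in for ptr2parameter

// Return the parameter tied to this vector, creating it on first use.
dynet::Parameter & lookupOrCreate(std::vector<float> * vec)
{
  auto it = cache.find(vec);
  if (it != cache.end())
    return it->second;
  // dynet::Dim wants unsigned, so cast the size_t from size().
  cache[vec] = model.add_parameters({(unsigned)vec->size(), 1});
  return cache.find(vec)->second;
}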
@@ -223,12 +224,6 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
return h;
}
- dynet::Expression MLP::getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label)
- {
-   dynet::Expression y = run(cg, x);
-   return pickneglogsoftmax(y, label);
- }
void MLP::printParameters(FILE * output)
{
for(auto & it : ptr2parameter)
@@ -244,3 +239,62 @@ void MLP::printParameters(FILE * output)
}
}
+ int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
+ {
+   dynet::ComputationGraph cg;
+   std::vector<dynet::Expression> inputs;
+   std::vector<unsigned int> goldClasses;
+   int inputDim = 0;
+   int outputDim = layers.back().output_dim;
+   for(auto it = start; it != end; it++)
+   {
+     std::vector<dynet::Expression> expressions;
+     expressions.clear();
+     for (auto & featValue : it->second.values)
+     {
+       if(featValue.policy == FeatureModel::Policy::Final)
+         expressions.emplace_back(dynet::const_parameter(cg, featValue2parameter(featValue)));
+       else
+         expressions.emplace_back(dynet::parameter(cg, featValue2parameter(featValue)));
+     }
+     inputs.emplace_back(dynet::concatenate(expressions));
+     inputDim = inputs.back().dim().rows();
+     goldClasses.emplace_back((unsigned)it->first);
+   }
+   dynet::Expression concatenation = dynet::concatenate(inputs);
+   int batchSize = end - start;
+   dynet::Expression batchedInput = reshape((concatenation),
+     dynet::Dim({(unsigned)inputDim}, batchSize));
+   dynet::Expression output = run(cg, batchedInput);
+   if(trainMode)
+   {
+     dynet::Expression batchedLoss = pickneglogsoftmax(output, goldClasses);
+     dynet::Expression loss = sum_batches(batchedLoss);
+     cg.backward(loss);
+     trainer.update();
+   }
+   int nbCorrect = 0;
+   std::vector<float> predictions = as_vector(output.value());
+   for (unsigned int i = 0; (int)i < batchSize; i++)
+   {
+     int prediction = 0;
+     for (unsigned int j = 0; (int)j < outputDim; j++)
+       if(predictions[i*outputDim+j] > predictions[i*outputDim+prediction])
+         prediction = (int)j;
+     if(prediction == (int)goldClasses[i])
+       nbCorrect++;
+   }
+   return nbCorrect;
+ }
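The function above follows dynet's standard mini-batching idiom: per-example input columns are concatenated, reshaped into a single expression whose dynet::Dim carries a batch dimension, run through the network once, then scored with pickneglogsoftmax over one gold label per batch element, with sum_batches reducing the per-example losses to a scalar. A minimal self-contained sketch of that idiom, assuming dynet 2.x (the toy model, dimensions, and labels are made up for illustration):

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <vector>

int main(int argc, char ** argv)
{
  dynet::initialize(argc, argv);
  dynet::ParameterCollection model;
  dynet::SimpleSGDTrainer trainer(model);

  const unsigned inputDim = 4, outputDim = 3, batchSize = 2;
  dynet::Parameter W = model.add_parameters({outputDim, inputDim});

  dynet::ComputationGraph cg;
  // Two input columns in one flat buffer, declared as a batch of 2
  // vectors of size 4 through the batch argument of dynet::Dim.
  std::vector<float> raw(inputDim * batchSize, 0.5f);
  dynet::Expression x = dynet::input(cg, dynet::Dim({inputDim}, batchSize), raw);
  dynet::Expression y = dynet::parameter(cg, W) * x;

  // One gold class per batch element; the loss stays batched until
  // sum_batches collapses it to a single scalar for backprop.
  std::vector<unsigned> golds = {0, 2};
  dynet::Expression loss = dynet::sum_batches(dynet::pickneglogsoftmax(y, golds));

  cg.forward(loss);
  cg.backward(loss);
  trainer.update();
  return 0;
}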
@@ -17,6 +17,7 @@ class ActionSet
ActionSet(const std::string & filename);
void printForDebug(FILE * output);
int getActionIndex(const std::string & name);
+ std::string getActionName(int actionIndex);
};
#endif
@@ -22,21 +22,26 @@ class Classifier
};
std::string name;
private :
Type type;
std::unique_ptr<FeatureModel> fm;
std::unique_ptr<ActionSet> as;
std::unique_ptr<MLP> mlp;
Oracle * oracle;
private :
- void initClassifier(Config & config);
public :
static Type str2type(const std::string & filename);
Classifier(const std::string & filename);
WeightedActions weightActions(Config & config, const std::string & goldAction);
+ FeatureModel::FeatureDescription getFeatureDescription(Config & config);
+ std::string getOracleAction(Config & config);
+ int getOracleActionIndex(Config & config);
+ int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
+ std::string getActionName(int actionIndex);
+ void initClassifier(Config & config);
};
#endif
@@ -40,3 +40,17 @@ int ActionSet::getActionIndex(const std::string & name)
return -1;
}
+ std::string ActionSet::getActionName(int actionIndex)
+ {
+   if(actionIndex >= 0 && actionIndex < (int)actions.size())
+   {
+     return actions[actionIndex].name;
+   }
+   fprintf(stderr, "ERROR (%s) : invalid action index \'%d\'. Aborting.\n", ERRINFO, actionIndex);
+   exit(1);
+   return "";
+ }
@@ -59,7 +59,6 @@ Classifier::Type Classifier::str2type(const std::string & s)
Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction)
{
- if(!mlp.get())
initClassifier(config);
int actionIndex = as->getActionIndex(goldAction);
@@ -83,6 +82,9 @@ Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction)
void Classifier::initClassifier(Config & config)
{
+ if(mlp.get())
+   return;
int nbInputs = 0;
int nbHidden = 200;
int nbOutputs = as->actions.size();
@@ -96,3 +98,28 @@ void Classifier::initClassifier(Config & config)
{nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
}
+ FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
+ {
+   return fm->getFeatureDescription(config);
+ }
+ std::string Classifier::getOracleAction(Config & config)
+ {
+   return oracle->getAction(config);
+ }
+ int Classifier::getOracleActionIndex(Config & config)
+ {
+   return as->getActionIndex(oracle->getAction(config));
+ }
+ int Classifier::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
+ {
+   return mlp->trainOnBatch(start, end);
+ }
+ std::string Classifier::getActionName(int actionIndex)
+ {
+   return as->getActionName(actionIndex);
+ }
@@ -16,6 +16,8 @@ class Trainer
private :
void printWeightedActions(FILE * output, Classifier::WeightedActions & wa);
+ void trainUnbatched();
+ void trainBatched();
public :
@@ -5,9 +5,9 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
{
}
- void Trainer::train()
+ void Trainer::trainUnbatched()
{
- int nbIter = 20;
+ int nbIter = 5;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
@@ -24,7 +24,7 @@ void Trainer::train()
//fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
- std::string neededActionName = classifier->oracle->getAction(config);
+ std::string neededActionName = classifier->getOracleAction(config);
auto weightedActions = classifier->weightActions(config, neededActionName);
//printWeightedActions(stderr, weightedActions);
std::string & predictedAction = weightedActions[0].second;
@@ -48,6 +48,67 @@ void Trainer::train()
}
}
+ void Trainer::trainBatched()
+ {
+   using FD = FeatureModel::FeatureDescription;
+   using Example = std::pair<int, FD>;
+   std::map<Classifier*, std::vector<Example> > examples;
+   while (!config.isFinal())
+   {
+     TapeMachine::State * currentState = tm.getCurrentState();
+     Classifier * classifier = currentState->classifier;
+     classifier->initClassifier(config);
+     int neededActionIndex = classifier->getOracleActionIndex(config);
+     std::string neededActionName = classifier->getActionName(neededActionIndex);
+     examples[classifier].emplace_back(Example(neededActionIndex, classifier->getFeatureDescription(config)));
+     TapeMachine::Transition * transition = tm.getTransition(neededActionName);
+     tm.takeTransition(transition);
+     config.moveHead(transition->headMvt);
+   }
+   int nbIter = 20;
+   int batchSize = 50;
+   for (int i = 0; i < nbIter; i++)
+   {
+     std::map< std::string, std::pair<int, int> > nbExamples;
+     for(auto & it : examples)
+     {
+       int nbBatches = (it.second.size() / batchSize) + (it.second.size() % batchSize ? 1 : 0);
+       for(int numBatch = 0; numBatch < nbBatches; numBatch++)
+       {
+         int currentBatchSize = std::min<int>(batchSize, it.second.size() - (numBatch * batchSize));
+         auto batchStart = it.second.begin() + numBatch * batchSize;
+         auto batchEnd = batchStart + currentBatchSize;
+         int nbCorrect = it.first->trainOnBatch(batchStart, batchEnd);
+         nbExamples[it.first->name].first += currentBatchSize;
+         nbExamples[it.first->name].second += nbCorrect;
+       }
+     }
+     fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
+     for(auto & it : nbExamples)
+       fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
+   }
+ }
+ void Trainer::train()
+ {
+   // trainUnbatched();
+   trainBatched();
+ }
void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
{
int nbCols = 80;
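A footnote on the partitioning arithmetic in trainBatched above: nbBatches rounds up, so a final short batch is still trained on, and std::min clips its size. A standalone check with made-up numbers:

#include <algorithm>
#include <cassert>

int main()
{
  int nbExamples = 132, batchSize = 50;
  // Ceiling division: two full batches plus one partial batch.
  int nbBatches = (nbExamples / batchSize) + (nbExamples % batchSize ? 1 : 0);
  assert(nbBatches == 3);
  // The last batch is clipped, here to the 32 remaining examples.
  int lastBatchSize = std::min(batchSize, nbExamples - (nbBatches - 1) * batchSize);
  assert(lastBatchSize == 32);
  return 0;
}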