Commit 9811325d authored by Franck Dary

First working version of training

parent a75a3935
@@ -57,6 +57,8 @@ class MLP
dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
inline dynet::Expression activate(dynet::Expression h, Activation f);
dynet::Expression getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label);
void printParameters(FILE * output);
public :
@@ -120,25 +120,16 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
}
dynet::Expression input = dynet::concatenate(expressions);
dynet::Expression output = run(cg, input);
/*
int nbInputs = layers[0].input_dim;
dynet::Expression x = reshape(concatenate_cols(cur_batch),
dynet::Dim({nb_inputs}, cur_batch_size));
dynet::Expression loss_expr = get_loss(x_batch, cur_labels);
dynet::Expression output = run(cg, input);
loss += as_scalar(computation_graph.forward(loss_expr));
nb_samples += cur_batch_size;
computation_graph.backward(loss_expr);
if(trainMode)
{
cg.backward(pickneglogsoftmax(output, goldClass));
trainer.update();
*/
std::vector<float> res = as_vector(cg.forward(output));
}
return res;
return as_vector(cg.forward(output));
}
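A minimal sketch of the batched loss the removed comment block was reaching for, assuming dynet's batching API; batch (a std::vector<dynet::Expression>), labels (a std::vector<unsigned>), nbInputs and batchSize are hypothetical placeholders:

// Pack the inputs column-wise, then reshape into a single batched
// expression: one nbInputs-dimensional input per batch element.
dynet::Expression x = dynet::reshape(dynet::concatenate_cols(batch),
                                     dynet::Dim({nbInputs}, batchSize));
// Negative log-likelihood of the gold labels, summed over the batch.
dynet::Expression loss = dynet::sum_batches(dynet::pickneglogsoftmax(run(cg, x), labels));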
dynet::DynetParams & MLP::getDefaultParams()
@@ -158,9 +149,11 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & f
if(it != ptr2parameter.end())
return it->second;
ptr2parameter[fv.vec] = model.add_parameters({1,fv.vec->size()}, dynet::ParameterInitFromVector(*fv.vec));
//ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1});
it = ptr2parameter.find(fv.vec);
it->second.set_updated(fv.policy == FeatureModel::Policy::Final ? false : true);
// it->second.values()->v = fv.vec->data();
return it->second;
}
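A minimal usage sketch of the caching above, given a dynet::ComputationGraph cg and a FeatureModel::FeatureValue fv (both names are placeholders):

// Each distinct feature vector is added to the model once; later calls
// retrieve the cached dynet::Parameter via the vector's address (fv.vec).
dynet::Parameter & p = featValue2parameter(fv);
// Policy::Final marks the parameter as frozen: set_updated(false) keeps
// the trainer from touching it during update().
dynet::Expression e = dynet::parameter(cg, p);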
@@ -230,3 +223,24 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
return h;
}
dynet::Expression MLP::getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label)
{
dynet::Expression y = run(cg, x);
return pickneglogsoftmax(y, label);
}
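A minimal sketch of how getLoss drives one online update, mirroring the commented-out code removed from predict() above; trainer (e.g. a dynet::SimpleSGDTrainer built over the same model), expressions and goldClass are assumptions:

dynet::ComputationGraph cg;
dynet::Expression input = dynet::concatenate(expressions); // as in predict()
dynet::Expression loss = getLoss(cg, input, goldClass);
// forward() evaluates the scalar loss; backward() computes the gradients
// that trainer.update() then applies to the model's parameters.
float lossValue = dynet::as_scalar(cg.forward(loss));
cg.backward(loss);
trainer.update();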
void MLP::printParameters(FILE * output)
{
for(auto & it : ptr2parameter)
{
auto & param = it.second;
dynet::Tensor * tensor = param.values();
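// tensor->v points at the raw float storage; d.size() is the element count.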
float * value = tensor->v;
int dim = tensor->d.size();
fprintf(output, "Param : ");
for(int i = 0; i < dim; i++)
fprintf(output, "%.2f ", value[i]);
fprintf(output, "\n");
}
}
@@ -29,13 +29,16 @@ class TapeMachine
private :
std::string name;
std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
std::map< std::string, std::unique_ptr<State> > str2state;
State * currentState;
public :
std::string name;
public :
TapeMachine(const std::string & filename);
State * getCurrentState();
Transition * getTransition(const std::string & action);
@@ -72,7 +72,11 @@ Classifier::WeightedActions Classifier::weightActions(Config & config, const std
for (unsigned int i = 0; i < scores.size(); i++)
result.emplace_back(scores[i], as->actions[i].name);
std::sort(result.begin(), result.end());
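// Sort by descending score, so the most confident action comes first.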
std::sort(result.begin(), result.end(),
[](const std::pair<float, std::string> & a, const std::pair<float, std::string> & b)
{
return a.first > b.first;
});
return result;
}
@@ -7,8 +7,14 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
void Trainer::train()
{
for (int i = 0; i < 2; i++)
int nbIter = 20;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
for (int i = 0; i < nbIter; i++)
{
std::map< std::string, std::pair<int, int> > nbExamples;
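// Per-classifier counters: first = examples seen, second = correct predictions.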
while (!config.isFinal())
{
TapeMachine::State * currentState = tm.getCurrentState();
@@ -20,7 +26,12 @@ void Trainer::train()
std::string neededActionName = classifier->oracle->getAction(config);
auto weightedActions = classifier->weightActions(config, neededActionName);
printWeightedActions(stderr, weightedActions);
//printWeightedActions(stderr, weightedActions);
std::string & predictedAction = weightedActions[0].second;
nbExamples[classifier->name].first++;
if(predictedAction == neededActionName)
nbExamples[classifier->name].second++;
//fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
@@ -29,13 +40,26 @@ void Trainer::train()
config.moveHead(transition->headMvt);
}
fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
for(auto & it : nbExamples)
fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
config.reset();
}
}
void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
{
int nbCols = 80;
char symbol = '-';
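// Draw a full-width separator line above and below the action list.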
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
for (auto it : wa)
fprintf(output, "%.2f\t%s\n", it.first, it.second.c_str());
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
}