Commit 9811325d authored by Franck Dary

First version of a working training

parent a75a3935
@@ -57,6 +57,8 @@ class MLP
     dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
     dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
     inline dynet::Expression activate(dynet::Expression h, Activation f);
+    dynet::Expression getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label);
+    void printParameters(FILE * output);
 
   public :
...
@@ -120,25 +120,16 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldClass)
   }
 
   dynet::Expression input = dynet::concatenate(expressions);
-  dynet::Expression output = run(cg, input);
 
-  /*
-  int nbInputs = layers[0].input_dim;
+  dynet::Expression output = run(cg, input);
 
-  dynet::Expression x = reshape(concatenate_cols(cur_batch),
-    dynet::Dim({nb_inputs}, cur_batch_size));
-  dynet::Expression loss_expr = get_loss(x_batch, cur_labels);
-  loss += as_scalar(computation_graph.forward(loss_expr));
-  nb_samples += cur_batch_size;
-  computation_graph.backward(loss_expr);
-  trainer.update();
-  */
+  if(trainMode)
+  {
+    cg.backward(pickneglogsoftmax(output, goldClass));
+    trainer.update();
+  }
 
-  std::vector<float> res = as_vector(cg.forward(output));
-
-  return res;
+  return as_vector(cg.forward(output));
 }
 
 dynet::DynetParams & MLP::getDefaultParams()
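
Note: the new training branch in predict() is DyNet's standard per-example update: run the network, take pickneglogsoftmax of the gold class as the loss, backpropagate, and let the trainer apply the gradients. Below is a minimal self-contained sketch of that pattern against DyNet 2.x; the toy linear scorer, its dimensions, and SimpleSGDTrainer are illustrative assumptions, not code from this commit.

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <vector>

int main(int argc, char ** argv)
{
  dynet::initialize(argc, argv);

  dynet::ParameterCollection model;        // named 'Model' in older DyNet versions
  dynet::SimpleSGDTrainer trainer(model);

  // Toy 2-class linear scorer; the shapes are arbitrary for the example.
  dynet::Parameter pW = model.add_parameters({2, 3});
  dynet::Parameter pb = model.add_parameters({2});

  std::vector<float> features = {0.5f, -1.0f, 2.0f};
  unsigned goldClass = 1;

  dynet::ComputationGraph cg;
  dynet::Expression W = dynet::parameter(cg, pW);
  dynet::Expression b = dynet::parameter(cg, pb);
  dynet::Expression x = dynet::input(cg, {3}, features);
  dynet::Expression output = W * x + b;

  // Same loss as in predict(): negative log-softmax picked at the gold class.
  dynet::Expression loss = dynet::pickneglogsoftmax(output, goldClass);
  cg.forward(loss);   // forward pass, makes the scalar loss available
  cg.backward(loss);  // accumulate gradients
  trainer.update();   // apply one SGD step

  return 0;
}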
@@ -158,9 +149,11 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & fv)
   if(it != ptr2parameter.end())
     return it->second;
 
-  ptr2parameter[fv.vec] = model.add_parameters({1,fv.vec->size()}, dynet::ParameterInitFromVector(*fv.vec));
+  //ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
+  ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1});
 
   it = ptr2parameter.find(fv.vec);
+  it->second.set_updated(fv.policy == FeatureModel::Policy::Final ? false : true);
+  // it->second.values()->v = fv.vec->data();
 
   return it->second;
 }
@@ -230,3 +223,24 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
 
   return h;
 }
+
+dynet::Expression MLP::getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label)
+{
+  dynet::Expression y = run(cg, x);
+
+  return pickneglogsoftmax(y, label);
+}
+
+void MLP::printParameters(FILE * output)
+{
+  for(auto & it : ptr2parameter)
+  {
+    auto & param = it.second;
+    dynet::Tensor * tensor = param.values();
+    float * value = tensor->v;
+    int dim = tensor->d.size();
+
+    fprintf(output, "Param : ");
+    for(int i = 0; i < dim; i++)
+      fprintf(output, "%.2f ", value[i]);
+    fprintf(output, "\n");
+  }
+}
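
Note: getLoss bundles the forward pass with the picked negative log-softmax, so a training loop only has to build the input expression, call backward, and update, much like the batched loop that was commented out of predict(). Since the declaration above sits in the private section, the caller below is a hypothetical sketch assuming the method were exposed; mlp, trainer, and trainOneExample are illustrative names.

#include <cstdio>

void trainOneExample(MLP & mlp, dynet::Trainer & trainer,
                     const std::vector<float> & feats, unsigned label)
{
  dynet::ComputationGraph cg;
  dynet::Expression x = dynet::input(cg, {(unsigned)feats.size()}, feats);
  dynet::Expression loss = mlp.getLoss(cg, x, label);
  fprintf(stderr, "loss = %f\n", dynet::as_scalar(cg.forward(loss)));
  cg.backward(loss);
  trainer.update();
}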
@@ -29,13 +29,16 @@ class TapeMachine
   private :
 
-  std::string name;
   std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
   std::map< std::string, std::unique_ptr<State> > str2state;
   State * currentState;
 
   public :
+
+  std::string name;
+
+  public :
 
   TapeMachine(const std::string & filename);
   State * getCurrentState();
   Transition * getTransition(const std::string & action);
...
@@ -72,7 +72,11 @@ Classifier::WeightedActions Classifier::weightActions(Config & config, const std...
   for (unsigned int i = 0; i < scores.size(); i++)
     result.emplace_back(scores[i], as->actions[i].name);
 
-  std::sort(result.begin(), result.end());
+  std::sort(result.begin(), result.end(),
+    [](const std::pair<float, std::string> & a, const std::pair<float, std::string> & b)
+    {
+      return a.first > b.first;
+    });
 
   return result;
 }
...
@@ -7,8 +7,14 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
 
 void Trainer::train()
 {
-  for (int i = 0; i < 2; i++)
+  int nbIter = 20;
+
+  fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
+
+  for (int i = 0; i < nbIter; i++)
   {
+    std::map< std::string, std::pair<int, int> > nbExamples;
+
     while (!config.isFinal())
     {
       TapeMachine::State * currentState = tm.getCurrentState();
@@ -20,7 +26,12 @@ void Trainer::train()
       std::string neededActionName = classifier->oracle->getAction(config);
       auto weightedActions = classifier->weightActions(config, neededActionName);
-      printWeightedActions(stderr, weightedActions);
+      //printWeightedActions(stderr, weightedActions);
+
+      std::string & predictedAction = weightedActions[0].second;
+
+      nbExamples[classifier->name].first++;
+      if(predictedAction == neededActionName)
+        nbExamples[classifier->name].second++;
 
       //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
@@ -29,13 +40,26 @@ void Trainer::train()
       config.moveHead(transition->headMvt);
     }
 
+    fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
+    for(auto & it : nbExamples)
+      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
+
     config.reset();
   }
 }
 
 void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
 {
+  int nbCols = 80;
+  char symbol = '-';
+
+  for(int i = 0; i < nbCols; i++)
+    fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
   for (auto it : wa)
     fprintf(output, "%.2f\t%s\n", it.first, it.second.c_str());
+  for(int i = 0; i < nbCols; i++)
+    fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
 }
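
Note: in the nbExamples map above, pair.first counts the examples a classifier has seen and pair.second the ones it predicted correctly, which is exactly what the per-iteration accuracy line divides. A tiny standalone illustration of the convention (the classifier name "Tagger" is made up):

#include <cstdio>
#include <map>
#include <string>
#include <utility>

int main()
{
  std::map< std::string, std::pair<int, int> > nbExamples;

  nbExamples["Tagger"].first++;  // one example seen...
  nbExamples["Tagger"].second++; // ...and predicted correctly
  nbExamples["Tagger"].first++;  // a second example, predicted wrong

  // Prints "Tagger 50.00% accuracy"
  for(auto & it : nbExamples)
    fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(),
            100.0*it.second.second / it.second.first);

  return 0;
}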