diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp
index 1c205fc1bbb10d6f663326074984923047b708db..40c6ba706e46a950a2ea0c08bb15b58aa772ce7d 100644
--- a/MLP/include/MLP.hpp
+++ b/MLP/include/MLP.hpp
@@ -57,13 +57,14 @@ class MLP
   dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
   dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
   inline dynet::Expression activate(dynet::Expression h, Activation f);
-  dynet::Expression getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label);
   void printParameters(FILE * output);

   public :

   MLP(std::vector<Layer> layers);
   std::vector<float> predict(FeatureModel::FeatureDescription & fd, int goldClass);
+
+  int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
 };

 #endif
diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp
index e6bef2afd37db09520f9146fa44609db4c5f1bea..11b993bf0e1a91f976d87f7808e5040e41a4a317 100644
--- a/MLP/src/MLP.cpp
+++ b/MLP/src/MLP.cpp
@@ -2,6 +2,7 @@
 #include "util.hpp"

 #include <dynet/param-init.h>
+#include <dynet/io.h>

 std::string MLP::activation2str(Activation a)
 {
@@ -150,7 +151,7 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & f
     return it->second;

   //ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
-  ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1});
+  ptr2parameter[fv.vec] = model.add_parameters({(unsigned)fv.vec->size(),1});
   it = ptr2parameter.find(fv.vec);

 // it->second.values()->v = fv.vec->data();
@@ -223,12 +224,6 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
   return h;
 }

-dynet::Expression MLP::getLoss(dynet::ComputationGraph & cg, dynet::Expression x, unsigned int label)
-{
-  dynet::Expression y = run(cg, x);
-  return pickneglogsoftmax(y, label);
-}
-
 void MLP::printParameters(FILE * output)
 {
   for(auto & it : ptr2parameter)
@@ -244,3 +239,62 @@ void MLP::printParameters(FILE * output)
   }
 }

+int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
+{
+  dynet::ComputationGraph cg;
+  std::vector<dynet::Expression> inputs;
+  std::vector<unsigned int> goldClasses;
+  int inputDim = 0;
+  int outputDim = layers.back().output_dim;
+
+  for(auto it = start; it != end; it++)
+  {
+    std::vector<dynet::Expression> expressions;
+    expressions.clear();
+
+    for (auto & featValue : it->second.values)
+    {
+      if(featValue.policy == FeatureModel::Policy::Final)
+        expressions.emplace_back(dynet::const_parameter(cg, featValue2parameter(featValue)));
+      else
+        expressions.emplace_back(dynet::parameter(cg, featValue2parameter(featValue)));
+    }
+
+    inputs.emplace_back(dynet::concatenate(expressions));
+    inputDim = inputs.back().dim().rows();
+    goldClasses.emplace_back((unsigned)it->first);
+  }
+
+  dynet::Expression concatenation = dynet::concatenate(inputs);
+  int batchSize = end - start;
+
+  dynet::Expression batchedInput = reshape((concatenation),
+    dynet::Dim({(unsigned)inputDim}, batchSize));
+
+  dynet::Expression output = run(cg, batchedInput);
+
+  if(trainMode)
+  {
+    dynet::Expression batchedLoss = pickneglogsoftmax(output, goldClasses);
+    dynet::Expression loss = sum_batches(batchedLoss);
+    cg.backward(loss);
+    trainer.update();
+  }
+
+  int nbCorrect = 0;
+  std::vector<float> predictions = as_vector(output.value());
+  for (unsigned int i = 0; (int)i < batchSize; i++)
+  {
+    int prediction = 0;
+
+    for (unsigned int j = 0; (int)j < outputDim; j++)
+      if(predictions[i*outputDim+j] > predictions[i*outputDim+prediction])
+        prediction = (int)j;
+
+    if(prediction == (int)goldClasses[i])
+      nbCorrect++;
+  }
+
+  return nbCorrect;
+}
+
diff --git a/tape_machine/include/ActionSet.hpp b/tape_machine/include/ActionSet.hpp
index 8c6099ec5c3ae7d8b6e4ccc815cd8fe24d64d3c0..e00f22a0bd7d2fa8ee330149bbfdcd5439c9cb05 100644
--- a/tape_machine/include/ActionSet.hpp
+++ b/tape_machine/include/ActionSet.hpp
@@ -17,6 +17,7 @@ class ActionSet
   ActionSet(const std::string & filename);
   void printForDebug(FILE * output);
   int getActionIndex(const std::string & name);
+  std::string getActionName(int actionIndex);
 };

 #endif
diff --git a/tape_machine/include/Classifier.hpp b/tape_machine/include/Classifier.hpp
index 2772453fed6f52dd7513530d0a4a52c7e9fcbbec..5b913c9e90f16e5a242f2db90d2276aba66c8191 100644
--- a/tape_machine/include/Classifier.hpp
+++ b/tape_machine/include/Classifier.hpp
@@ -22,21 +22,26 @@ class Classifier
   };

   std::string name;
+
+  private :
+
   Type type;
   std::unique_ptr<FeatureModel> fm;
   std::unique_ptr<ActionSet> as;
   std::unique_ptr<MLP> mlp;
   Oracle * oracle;

-  private :
-
-  void initClassifier(Config & config);
-
   public :

   static Type str2type(const std::string & filename);
   Classifier(const std::string & filename);
   WeightedActions weightActions(Config & config, const std::string & goldAction);
+  FeatureModel::FeatureDescription getFeatureDescription(Config & config);
+  std::string getOracleAction(Config & config);
+  int getOracleActionIndex(Config & config);
+  int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
+  std::string getActionName(int actionIndex);
+  void initClassifier(Config & config);
 };

 #endif
diff --git a/tape_machine/src/ActionSet.cpp b/tape_machine/src/ActionSet.cpp
index b2dd09836eb5292b3cc963c9ac0646f515c87ba8..8932d8ca03fbbbbd3138a107e171d713e537ae91 100644
--- a/tape_machine/src/ActionSet.cpp
+++ b/tape_machine/src/ActionSet.cpp
@@ -40,3 +40,17 @@ int ActionSet::getActionIndex(const std::string & name)
   return -1;
 }

+std::string ActionSet::getActionName(int actionIndex)
+{
+  if(actionIndex >= 0 && actionIndex < (int)actions.size())
+  {
+    return actions[actionIndex].name;
+  }
+
+  fprintf(stderr, "ERROR (%s) : invalid action index \'%d\'. Aborting.\n", ERRINFO, actionIndex);
+
+  exit(1);
+
+  return "";
+}
+
diff --git a/tape_machine/src/Classifier.cpp b/tape_machine/src/Classifier.cpp
index f728c1491fde22cac3500ec8a6a02f2b8897ab2f..7f6cd608d251b4e4872f6e6101de36c19cb122d5 100644
--- a/tape_machine/src/Classifier.cpp
+++ b/tape_machine/src/Classifier.cpp
@@ -59,8 +59,7 @@ Classifier::Type Classifier::str2type(const std::string & s)

 Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction)
 {
-  if(!mlp.get())
-    initClassifier(config);
+  initClassifier(config);

   int actionIndex = as->getActionIndex(goldAction);

@@ -83,6 +82,9 @@ Classifier::WeightedActions Classifier::weightActions(Config & config, const std

 void Classifier::initClassifier(Config & config)
 {
+  if(mlp.get())
+    return;
+
   int nbInputs = 0;
   int nbHidden = 200;
   int nbOutputs = as->actions.size();
@@ -96,3 +98,28 @@ void Classifier::initClassifier(Config & config)
     {nbHidden, nbOutputs, 0.0, MLP::Activation::LINEAR}}));
 }

+FeatureModel::FeatureDescription Classifier::getFeatureDescription(Config & config)
+{
+  return fm->getFeatureDescription(config);
+}
+
+std::string Classifier::getOracleAction(Config & config)
+{
+  return oracle->getAction(config);
+}
+
+int Classifier::getOracleActionIndex(Config & config)
+{
+  return as->getActionIndex(oracle->getAction(config));
+}
+
+int Classifier::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
+{
+  return mlp->trainOnBatch(start, end);
+}
+
+std::string Classifier::getActionName(int actionIndex)
+{
+  return as->getActionName(actionIndex);
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index f2246148da08fdf128d40652c00d790c7370808b..361d215354f92720b4975339af51005961d2fa30 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -16,6 +16,8 @@ class Trainer
   private :

   void printWeightedActions(FILE * output, Classifier::WeightedActions & wa);
+  void trainUnbatched();
+  void trainBatched();

   public :

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 3f181dff14b92cad7a2866a8df40198294427e5d..4f6aca255e5a6598dfc34dc1db830d007d28b9bd 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -5,9 +5,9 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
 {
 }

-void Trainer::train()
+void Trainer::trainUnbatched()
 {
-  int nbIter = 20;
+  int nbIter = 5;

   fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());

@@ -24,7 +24,7 @@ void Trainer::train()

     //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());

-    std::string neededActionName = classifier->oracle->getAction(config);
+    std::string neededActionName = classifier->getOracleAction(config);
     auto weightedActions = classifier->weightActions(config, neededActionName);
     //printWeightedActions(stderr, weightedActions);
     std::string & predictedAction = weightedActions[0].second;
@@ -48,6 +48,67 @@ void Trainer::train()
   }
 }

+void Trainer::trainBatched()
+{
+  using FD = FeatureModel::FeatureDescription;
+  using Example = std::pair<int, FD>;
+
+  std::map<Classifier*, std::vector<Example> > examples;
+
+  while (!config.isFinal())
+  {
+    TapeMachine::State * currentState = tm.getCurrentState();
+    Classifier * classifier = currentState->classifier;
+    classifier->initClassifier(config);
+
+    int neededActionIndex = classifier->getOracleActionIndex(config);
+    std::string neededActionName = classifier->getActionName(neededActionIndex);
+
+    examples[classifier].emplace_back(Example(neededActionIndex, classifier->getFeatureDescription(config)));
+
+    TapeMachine::Transition * transition = tm.getTransition(neededActionName);
+    tm.takeTransition(transition);
+    config.moveHead(transition->headMvt);
+  }
+
+  int nbIter = 20;
+  int batchSize = 50;
+
+  for (int i = 0; i < nbIter; i++)
+  {
+    std::map< std::string, std::pair<int, int> > nbExamples;
+
+    for(auto & it : examples)
+    {
+      int nbBatches = (it.second.size() / batchSize) + (it.second.size() % batchSize ? 1 : 0);
+
+      for(int numBatch = 0; numBatch < nbBatches; numBatch++)
+      {
+        int currentBatchSize = std::min<int>(batchSize, it.second.size() - (numBatch * batchSize));
+
+        auto batchStart = it.second.begin() + numBatch * batchSize;
+        auto batchEnd = batchStart + currentBatchSize;
+
+        int nbCorrect = it.first->trainOnBatch(batchStart, batchEnd);
+
+        nbExamples[it.first->name].first += currentBatchSize;
+        nbExamples[it.first->name].second += nbCorrect;
+      }
+    }
+
+    fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
+    for(auto & it : nbExamples)
+      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
+
+  }
+}
+
+void Trainer::train()
+{
+//  trainUnbatched();
+  trainBatched();
+}
+
 void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
 {
   int nbCols = 80;
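
Note on the DyNet batching pattern used by MLP::trainOnBatch above: the per-example input columns are concatenated into one long vector, reshaped into a single expression of dimension {inputDim} carrying batchSize batch elements, and scored in one pass; pickneglogsoftmax applied with a vector of gold classes then yields one loss per batch element, and sum_batches folds these into the scalar that is backpropagated. Below is a minimal, self-contained sketch of that pattern only. The toy dimensions, the constant input data, the single affine layer, and the SimpleSGDTrainer are illustrative assumptions (as is the ParameterCollection type, which presumes a recent DyNet), not part of this patch.

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <vector>

int main(int argc, char ** argv)
{
  dynet::initialize(argc, argv);

  // Toy sizes, chosen only for illustration.
  const unsigned inputDim = 4, outputDim = 3, batchSize = 2;

  dynet::ParameterCollection model;
  dynet::SimpleSGDTrainer trainer(model);
  dynet::Parameter pW = model.add_parameters({outputDim, inputDim});
  dynet::Parameter pb = model.add_parameters({outputDim});

  // One flat vector holding batchSize input columns laid out end to end,
  // and one gold class per batch element (made-up values).
  std::vector<float> flatInput(batchSize * inputDim, 0.5f);
  std::vector<unsigned> goldClasses = {0, 2};

  dynet::ComputationGraph cg;
  dynet::Expression W = dynet::parameter(cg, pW);
  dynet::Expression b = dynet::parameter(cg, pb);

  // As in trainOnBatch : turn the concatenation into a batched expression
  // of dimension {inputDim} with batchSize batch elements.
  dynet::Expression concatenation = dynet::input(cg, {batchSize * inputDim}, flatInput);
  dynet::Expression batchedInput = dynet::reshape(concatenation, dynet::Dim({inputDim}, batchSize));

  // A single affine transform scores every batch element at once;
  // run(cg, batchedInput) plays this role in the real code.
  dynet::Expression output = W * batchedInput + b;

  // One negative log-likelihood per batch element, summed into a scalar.
  dynet::Expression batchedLoss = dynet::pickneglogsoftmax(output, goldClasses);
  dynet::Expression loss = dynet::sum_batches(batchedLoss);

  cg.forward(loss);
  cg.backward(loss);
  trainer.update();

  return 0;
}

The explicit cg.forward(loss) before cg.backward(loss) follows the conventional DyNet calling sequence, and the no-argument trainer.update() matches the call in trainOnBatch.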