diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp
index 05d5ef3c00b25ce5273ba2bddb199e6123c1fb09..14f174d6e4375ea2503f73f78c8eeb45586fc60d 100644
--- a/MLP/include/MLP.hpp
+++ b/MLP/include/MLP.hpp
@@ -60,13 +60,20 @@ class MLP
   dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
   inline dynet::Expression activate(dynet::Expression h, Activation f);
   void printParameters(FILE * output);
+  void saveStruct(const std::string & filename);
+  void saveParameters(const std::string & filename);
+  void loadStruct(const std::string & filename);
+  void loadParameters(const std::string & filename);
+  void load(const std::string & filename);
 
   public :
 
   MLP(std::vector<Layer> layers);
+  MLP(const std::string & filename);
   std::vector<float> predict(FeatureModel::FeatureDescription & fd, int goldClass);
   int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
+  void save(const std::string & filename);
 };
 
 #endif
diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp
index 6213bf698c7badc46b18635883f77aaa3ae064bb..852895421bf3f7e4197ee8ac7de881ff6c9e83bd 100644
--- a/MLP/src/MLP.cpp
+++ b/MLP/src/MLP.cpp
@@ -1,4 +1,5 @@
 #include "MLP.hpp"
+#include "File.hpp"
 #include "util.hpp"
 
 #include <dynet/param-init.h>
@@ -113,9 +114,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
   std::vector<dynet::Expression> expressions;
 
   for (auto & featValue : fd.values)
-  {
     expressions.emplace_back(featValue2Expression(cg, featValue));
-  }
 
   dynet::Expression input = dynet::concatenate(expressions);
@@ -310,3 +309,79 @@ int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescriptio
 
   return nbCorrect;
 }
+
+void MLP::save(const std::string & filename)
+{
+  saveStruct(filename);
+  saveParameters(filename);
+}
+
+void MLP::saveStruct(const std::string & filename)
+{
+  File file(filename, "w");
+  FILE * fd = file.getDescriptor();
+
+  for (auto & layer : layers)
+  {
+    fprintf(fd, "Layer : %d %d %s %.2f\n", layer.input_dim, layer.output_dim, activation2str(layer.activation).c_str(), layer.dropout_rate);
+  }
+}
+
+void MLP::saveParameters(const std::string & filename)
+{
+  dynet::TextFileSaver s(filename, true);
+  std::string prefix("Layer_");
+
+  for(unsigned int i = 0; i < parameters.size(); i++)
+  {
+    s.save(parameters[i][0], prefix + std::to_string(i) + "_W");
+    s.save(parameters[i][1], prefix + std::to_string(i) + "_b");
+  }
+}
+
+void MLP::load(const std::string & filename)
+{
+  loadStruct(filename);
+  loadParameters(filename);
+}
+
+void MLP::loadStruct(const std::string & filename)
+{
+  File file(filename, "r");
+  FILE * fd = file.getDescriptor();
+
+  char activation[1024];
+  int input;
+  int output;
+  float dropout;
+
+  while (fscanf(fd, "Layer : %d %d %1023s %f\n", &input, &output, activation, &dropout) == 4)
+    layers.emplace_back(input, output, dropout, str2activation(activation));
+
+  checkLayersCompatibility();
+
+  for (auto & layer : layers)
+    addLayerToModel(layer);
+}
+
+void MLP::loadParameters(const std::string & filename)
+{
+  dynet::TextFileLoader loader(filename);
+  std::string prefix("Layer_");
+
+  for(unsigned int i = 0; i < parameters.size(); i++)
+  {
+    parameters[i][0] = loader.load_param(model, prefix + std::to_string(i) + "_W");
+    parameters[i][1] = loader.load_param(model, prefix + std::to_string(i) + "_b");
+  }
+}
+
+MLP::MLP(const std::string & filename)
+: trainer(model, 0.001, 0.9, 0.999, 1e-8)
+{
+  dynet::initialize(getDefaultParams());
+
+  trainMode = false;
+
+  load(filename);
+}
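
Note on the new MLP persistence: saveStruct() and saveParameters() write to the same file, structure first as plain "Layer : ..." lines, then the weights, since dynet's TextFileSaver is opened in append mode; loadStruct()'s fscanf simply stops matching once it reaches the parameter records. A minimal in-process round trip, as a sketch only: the Layer argument order is inferred from loadStruct() above, "RELU" stands for whatever names str2activation() accepts, and "model.txt" is a placeholder (the patch itself hardcodes "toto.txt"):

  #include "MLP.hpp"

  int main()
  {
    // A tiny 10 -> 5 network; Layer(input, output, dropout, activation)
    // mirrors the order loadStruct() uses when re-reading the file.
    std::vector<Layer> layers;
    layers.emplace_back(10, 5, 0.0f, str2activation("RELU"));

    MLP mlp(layers);           // train-mode constructor
    mlp.save("model.txt");     // "Layer : ..." lines, then appended dynet parameters

    MLP restored("model.txt"); // decode-mode constructor: loadStruct() + loadParameters()
    return 0;
  }

In the patch proper the two constructors live in separate binaries (test_train / test_decode), which also avoids initializing dynet twice in one process.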
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 779e59b417939ab8f9b3eab0c84759b32fcb619c..d6f2e5c673b7b998bc512ecf35eeb957e0d91966 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -7,6 +7,44 @@ Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config)
 
 void Decoder::decode()
 {
+  int nbIter = 1;
+  fprintf(stderr, "Decoding of \'%s\' :\n", tm.name.c_str());
+
+  for (int i = 0; i < nbIter; i++)
+  {
+    std::map< std::string, std::pair<int, int> > nbExamples;
+
+    while (!config.isFinal())
+    {
+      TapeMachine::State * currentState = tm.getCurrentState();
+      Classifier * classifier = currentState->classifier;
+
+      //config.printForDebug(stderr);
+
+      //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());
+
+      std::string neededActionName = classifier->getOracleAction(config);
+      auto weightedActions = classifier->weightActions(config, neededActionName);
+      //Classifier::printWeightedActions(stderr, weightedActions);
+      std::string & predictedAction = weightedActions[0].second;
+
+      nbExamples[classifier->name].first++;
+      if(predictedAction == neededActionName)
+        nbExamples[classifier->name].second++;
+
+      //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
+
+      TapeMachine::Transition * transition = tm.getTransition(neededActionName);
+      tm.takeTransition(transition);
+      config.moveHead(transition->headMvt);
+    }
+
+    fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
+    for(auto & it : nbExamples)
+      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
+
+    config.reset();
+  }
 }
diff --git a/maca_common/include/Dict.hpp b/maca_common/include/Dict.hpp
index 996aaa43929e2423fecac61113e2f78f666de9ac..7415ba4fdf314535150ca186b304b3b76aa318a3 100644
--- a/maca_common/include/Dict.hpp
+++ b/maca_common/include/Dict.hpp
@@ -59,6 +59,7 @@ class Dict
   std::vector<float> * getValue(const std::string & s);
   std::vector<float> * getNullValue();
   int getDimension();
+  void printForDebug(FILE * output);
 };
 
 #endif
diff --git a/maca_common/src/Dict.cpp b/maca_common/src/Dict.cpp
index 125f24c4802c4d8ae32e819ca1c75ddb15775342..ea355dafcfe7584a04292011cd0a17a1c04631ab 100644
--- a/maca_common/src/Dict.cpp
+++ b/maca_common/src/Dict.cpp
@@ -71,10 +71,11 @@ Dict::Dict(Policy policy, const std::string & filename)
   if(this->policy == Policy::FromZero)
     return;
 
-  while(fscanf(fd, "%s", b1) != 1)
+  while(fscanf(fd, "%s", b1) == 1)
   {
     std::string entry = b1;
-    str2vec.emplace(entry, std::vector<float>());
+    //str2vec.emplace(entry, std::vector<float>());
+    str2vec[entry] = std::vector<float>();
     auto & vec = str2vec[entry];
 
     // For OneHot we only write the index
@@ -195,3 +196,8 @@ int Dict::getDimension()
 
   return dimension;
 }
+
+void Dict::printForDebug(FILE * output)
+{
+  fprintf(output, "Dict name \'%s\' nbElems = %zu\n", name.c_str(), str2vec.size());
+}
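
Two distinct fixes land in Dict's loader. The loop condition used to be while(fscanf(...) != 1), i.e. loop while the read fails, so the dictionary never ingested a single entry; == 1 restores the intended "read until exhausted" behaviour. Independently, emplace was swapped for operator[] assignment: emplace is a no-op when the key already exists, so a duplicated entry in the file would keep growing the previously stored vector, while assignment starts from a fresh empty one. A standalone illustration of that second point (not repository code):

  #include <cstdio>
  #include <map>
  #include <string>
  #include <vector>

  int main()
  {
    std::map<std::string, std::vector<float>> str2vec;

    str2vec["the"].push_back(1.0f);               // first occurrence: one element

    str2vec.emplace("the", std::vector<float>()); // no-op: key already present
    std::printf("%zu\n", str2vec["the"].size());  // prints 1

    str2vec["the"] = std::vector<float>();        // overwrites with an empty vector
    std::printf("%zu\n", str2vec["the"].size());  // prints 0

    return 0;
  }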
diff --git a/tape_machine/include/Classifier.hpp b/tape_machine/include/Classifier.hpp
index 7ee75bb35df98a1a92269311220c2666d0345aca..17737346422e446f1a6f25b6973e874bfa3d779c 100644
--- a/tape_machine/include/Classifier.hpp
+++ b/tape_machine/include/Classifier.hpp
@@ -25,6 +25,7 @@ class Classifier
 
   private :
 
+  bool trainMode;
   Type type;
   std::unique_ptr<FeatureModel> fm;
   std::unique_ptr<ActionSet> as;
@@ -36,7 +37,7 @@ class Classifier
 
   static void printWeightedActions(FILE * output, WeightedActions & wa);
   static Type str2type(const std::string & filename);
-  Classifier(const std::string & filename);
+  Classifier(const std::string & filename, bool trainMode);
   WeightedActions weightActions(Config & config, const std::string & goldAction);
   FeatureModel::FeatureDescription getFeatureDescription(Config & config);
   std::string getOracleAction(Config & config);
@@ -44,6 +45,7 @@ class Classifier
   int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
   std::string getActionName(int actionIndex);
   void initClassifier(Config & config);
+  void save(const std::string & filename);
 };
 
 #endif
diff --git a/tape_machine/include/TapeMachine.hpp b/tape_machine/include/TapeMachine.hpp
index ac562f7fea42cdf89c320b04a3b39ec8d9b4409b..e4504698788ac468317da966f76f2e67a0d53ba5 100644
--- a/tape_machine/include/TapeMachine.hpp
+++ b/tape_machine/include/TapeMachine.hpp
@@ -29,9 +29,11 @@ class TapeMachine
 
   private :
 
+  bool trainMode;
   std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
   std::map< std::string, std::unique_ptr<State> > str2state;
   State * currentState;
+  std::vector<Classifier*> classifiers;
 
   public :
 
@@ -39,10 +41,11 @@ class TapeMachine
 
   public :
 
-  TapeMachine(const std::string & filename);
+  TapeMachine(const std::string & filename, bool trainMode);
   State * getCurrentState();
   Transition * getTransition(const std::string & action);
   void takeTransition(Transition * transition);
+  std::vector<Classifier*> & getClassifiers();
 };
 
 #endif
diff --git a/tape_machine/src/Classifier.cpp b/tape_machine/src/Classifier.cpp
index 9e3ff5efa694e79e1fc17a80ecadf53f3595ab20..f6fa917cd9621bcb79d36571cf1778351698c2bb 100644
--- a/tape_machine/src/Classifier.cpp
+++ b/tape_machine/src/Classifier.cpp
@@ -2,8 +2,10 @@
 #include "File.hpp"
 #include "util.hpp"
 
-Classifier::Classifier(const std::string & filename)
+Classifier::Classifier(const std::string & filename, bool trainMode)
 {
+  this->trainMode = trainMode;
+
   auto badFormatAndAbort = [&filename](const char * errInfo)
   {
     fprintf(stderr, "ERROR (%s) : file %s bad format. Aborting.\n", errInfo, filename.c_str());
@@ -85,6 +87,12 @@ void Classifier::initClassifier(Config & config)
   if(mlp.get())
     return;
 
+  if(!trainMode)
+  {
+    mlp.reset(new MLP("toto.txt"));
+    return;
+  }
+
   int nbInputs = 0;
   int nbHidden = 200;
   int nbOutputs = as->actions.size();
@@ -138,3 +146,8 @@ void Classifier::printWeightedActions(FILE * output, WeightedActions & wa)
       fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
 }
+
+void Classifier::save(const std::string & filename)
+{
+  mlp->save(filename);
+}
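
The new trainMode flag only changes behaviour in one spot: initClassifier() short-circuits to mlp.reset(new MLP("toto.txt")) instead of assembling a fresh network from the feature model and action set. A hypothetical pair of call sites, sketched under the assumption that the constructor stays public as the header suggests ("parser.cla" is a placeholder file name, "toto.txt" the patch's hardcoded model file):

  #include "Classifier.hpp"

  int main()
  {
    Classifier training("parser.cla", /*trainMode=*/true);  // first initClassifier() builds a fresh MLP
    Classifier decoding("parser.cla", /*trainMode=*/false); // first initClassifier() reloads "toto.txt"
    return 0;
  }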
"\n" : ""); } +void Classifier::save(const std::string & filename) +{ + mlp->save(filename); +} + diff --git a/tape_machine/src/TapeMachine.cpp b/tape_machine/src/TapeMachine.cpp index da26748e5dfea0072a8613e0ec7514d37450f604..1356183b4a5c30b7060aab873e446010cac11525 100644 --- a/tape_machine/src/TapeMachine.cpp +++ b/tape_machine/src/TapeMachine.cpp @@ -3,7 +3,7 @@ #include "util.hpp" #include <cstring> -TapeMachine::TapeMachine(const std::string & filename) +TapeMachine::TapeMachine(const std::string & filename, bool trainMode) { auto badFormatAndAbort = [&filename](const std::string & errInfo) { @@ -12,6 +12,8 @@ TapeMachine::TapeMachine(const std::string & filename) exit(1); }; + this->trainMode = trainMode; + File file(filename, "r"); FILE * fd = file.getDescriptor(); @@ -35,7 +37,9 @@ TapeMachine::TapeMachine(const std::string & filename) if(fscanf(fd, "%s %s\n", buffer, buffer2) != 2) badFormatAndAbort(ERRINFO); - str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2))); + str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2, trainMode))); + + classifiers.emplace_back(str2classifier[buffer].get()); } // Reading %STATES @@ -124,3 +128,8 @@ void TapeMachine::takeTransition(Transition * transition) currentState = transition->dest; } +std::vector<Classifier*> & TapeMachine::getClassifiers() +{ + return classifiers; +} + diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3da1b40b4d359fd6f01a9ed82b936a65f33de717..c76bc3dd143b8ae25f851b7c85e5d507f3232985 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -3,3 +3,7 @@ FILE(GLOB SOURCES src/*.cpp) add_executable(test_train src/test_train.cpp) target_link_libraries(test_train tape_machine) target_link_libraries(test_train trainer) + +add_executable(test_decode src/test_decode.cpp) +target_link_libraries(test_decode tape_machine) +target_link_libraries(test_decode decoder) diff --git a/tests/src/test_decode.cpp b/tests/src/test_decode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2237a734bfc24d7ff1993e09fce1a100c3f15044 --- /dev/null +++ b/tests/src/test_decode.cpp @@ -0,0 +1,32 @@ +#include <cstdio> +#include <cstdlib> +#include "MCD.hpp" +#include "Config.hpp" +#include "TapeMachine.hpp" +#include "Decoder.hpp" + +void printUsageAndExit(char * argv[]) +{ + fprintf(stderr, "USAGE : %s mcd inputFile tm\n", *argv); + exit(1); +} + +int main(int argc, char * argv[]) +{ + if (argc != 4) + printUsageAndExit(argv); + + MCD mcd(argv[1]); + Config config(mcd); + + TapeMachine tapeMachine(argv[3], false); + + config.readInput(argv[2]); + + Decoder decoder(tapeMachine, mcd, config); + + decoder.decode(); + + return 0; +} + diff --git a/tests/src/test_train.cpp b/tests/src/test_train.cpp index 3d48cd144ae9398f569b1c7109d91fc2e87ff714..d2d17665ccb76b21c48ab889910c570e0deb5096 100644 --- a/tests/src/test_train.cpp +++ b/tests/src/test_train.cpp @@ -19,7 +19,7 @@ int main(int argc, char * argv[]) MCD mcd(argv[1]); Config config(mcd); - TapeMachine tapeMachine(argv[3]); + TapeMachine tapeMachine(argv[3], true); config.readInput(argv[2]); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d20ee64f128973763658c1f412982f02c786d49d..fdb364c3a77fd97d03004442fe00da68ad9342a1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -7,7 +7,7 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config) void Trainer::trainUnbatched() { - int nbIter = 5; + int nbIter = 20; fprintf(stderr, "Training of \'%s\' :\n", 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index d20ee64f128973763658c1f412982f02c786d49d..fdb364c3a77fd97d03004442fe00da68ad9342a1 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -7,7 +7,7 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
 
 void Trainer::trainUnbatched()
 {
-  int nbIter = 5;
+  int nbIter = 20;
 
   fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
 
@@ -26,7 +26,7 @@ void Trainer::trainUnbatched()
 
       std::string neededActionName = classifier->getOracleAction(config);
       auto weightedActions = classifier->weightActions(config, neededActionName);
-      //printWeightedActions(stderr, weightedActions);
+      //Classifier::printWeightedActions(stderr, weightedActions);
       std::string & predictedAction = weightedActions[0].second;
 
       nbExamples[classifier->name].first++;
@@ -46,6 +46,10 @@ void Trainer::trainUnbatched()
 
     config.reset();
   }
+
+  auto & classifiers = tm.getClassifiers();
+  for(Classifier * cla : classifiers)
+    cla->save("toto.txt");
 }
 
 void Trainer::trainBatched()
@@ -101,13 +105,16 @@ void Trainer::trainBatched()
 
     fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
     for(auto & it : nbExamples)
      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
   }
+
+  auto & classifiers = tm.getClassifiers();
+  for(Classifier * cla : classifiers)
+    cla->save("toto.txt");
 }
 
 void Trainer::train()
 {
-// trainUnbatched();
+  //trainUnbatched();
   trainBatched();
 }
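
One caveat in the two save loops above: every classifier writes to the same hardcoded "toto.txt", so a machine with more than one classifier would leave only the last model on disk, and the load path would then feed that single file to all of them. A possible follow-up, assuming Classifier::name remains publicly readable (Trainer and Decoder already use it for the accuracy table):

  auto & classifiers = tm.getClassifiers();
  for(Classifier * cla : classifiers)
    cla->save(cla->name + ".model"); // hypothetical per-classifier file name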