diff --git a/MLP/include/MLP.hpp b/MLP/include/MLP.hpp index 5723d753c213dc6cce17271d5898f62f287db35b..28c78786cce402808faa0d05177e1fa27da871e9 100644 --- a/MLP/include/MLP.hpp +++ b/MLP/include/MLP.hpp @@ -72,7 +72,7 @@ class MLP MLP(std::vector<Layer> layers); MLP(const std::string & filename); - std::vector<float> predict(FeatureModel::FeatureDescription & fd, int goldClass); + std::vector<float> predict(FeatureModel::FeatureDescription & fd); int trainOnBatch(Examples & examples, int start, int end); int getScoreOnBatch(Examples & examples, int start, int end); diff --git a/MLP/src/MLP.cpp b/MLP/src/MLP.cpp index 0222e96b20e58f5d599598876951e5b1ad0f614c..e2789b7e4c1a5fe4d223e18c64757bcd5563ac39 100644 --- a/MLP/src/MLP.cpp +++ b/MLP/src/MLP.cpp @@ -107,7 +107,7 @@ MLP::Layer::Layer(int input_dim, int output_dim, this->activation = activation; } -std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldClass) +std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd) { dynet::ComputationGraph cg; @@ -120,12 +120,6 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC dynet::Expression output = run(cg, input); - if(trainMode) - { - cg.backward(pickneglogsoftmax(output, goldClass)); - trainer.update(); - } - return as_vector(cg.forward(output)); } diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index d6f2e5c673b7b998bc512ecf35eeb957e0d91966..9f306ffe8c4cb0a18df2a7ec4d7ae35c2cb7c783 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -1,4 +1,5 @@ #include "Decoder.hpp" +#include "util.hpp" Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config) : tm(tm), mcd(mcd), config(config) @@ -7,44 +8,29 @@ Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config) void Decoder::decode() { - int nbIter = 1; - - fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str()); - - for (int i = 0; i < nbIter; i++) + while (!config.isFinal()) { - std::map< std::string, std::pair<int, int> > nbExamples; - - while (!config.isFinal()) - { - TapeMachine::State * currentState = tm.getCurrentState(); - Classifier * classifier = currentState->classifier; - - //config.printForDebug(stderr); + TapeMachine::State * currentState = tm.getCurrentState(); + Classifier * classifier = currentState->classifier; - //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); + //config.printForDebug(stderr); - std::string neededActionName = classifier->getOracleAction(config); - auto weightedActions = classifier->weightActions(config, neededActionName); - //Classifier::printWeightedActions(stderr, weightedActions); - std::string & predictedAction = weightedActions[0].second; + //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str()); - nbExamples[classifier->name].first++; - if(predictedAction == neededActionName) - nbExamples[classifier->name].second++; + auto weightedActions = classifier->weightActions(config); + //Classifier::printWeightedActions(stderr, weightedActions); + std::string & predictedAction = weightedActions[0].second; - //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str()); + Action * action = classifier->getAction(predictedAction); + if(!action->appliable(config)) + fprintf(stderr, "WARNING (%s) : action \'%s\' is not appliable.\n", ERRINFO, predictedAction.c_str()); + action->apply(config); - TapeMachine::Transition * transition = tm.getTransition(neededActionName); - tm.takeTransition(transition); - config.moveHead(transition->headMvt); - } - - fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter); - for(auto & it : nbExamples) - fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first); - - config.reset(); + TapeMachine::Transition * transition = tm.getTransition(predictedAction); + tm.takeTransition(transition); + config.moveHead(transition->headMvt); } + + config.printAsOutput(stdout); } diff --git a/maca_common/src/File.cpp b/maca_common/src/File.cpp index b0c2d0845a79b4050e7e109bf0041e0ff821f374..3f8eaf2b60206b86aa241d8629016761aa803dd0 100644 --- a/maca_common/src/File.cpp +++ b/maca_common/src/File.cpp @@ -32,7 +32,7 @@ File::File(const std::string & filename, const std::string & mode) if (mode != "r" && mode != "w") { - printf("\"%s\" is an invalid mode when opening a file\n", mode.c_str()); + fprintf(stderr, "ERROR (%s) : \"%s\" is an invalid mode when opening a file\n", ERRINFO, mode.c_str()); exit(1); } @@ -52,7 +52,7 @@ File::File(const std::string & filename, const std::string & mode) if (!file) { - printf("Cannot open file %s\n", filename.c_str()); + fprintf(stderr, "ERROR (%s) : cannot open file %s\n", ERRINFO, filename.c_str()); exit(1); } @@ -125,7 +125,7 @@ void File::rewind() if (!file) { - printf("Cannot open file %s\n", filename.c_str()); + fprintf(stderr, "ERROR (%s) : Cannot open file %s\n", ERRINFO, filename.c_str()); exit(1); } diff --git a/tape_machine/include/ActionSet.hpp b/tape_machine/include/ActionSet.hpp index e00f22a0bd7d2fa8ee330149bbfdcd5439c9cb05..c4d590bac51dc2f54e80e65488eae051ee467c39 100644 --- a/tape_machine/include/ActionSet.hpp +++ b/tape_machine/include/ActionSet.hpp @@ -18,6 +18,7 @@ class ActionSet void printForDebug(FILE * output); int getActionIndex(const std::string & name); std::string getActionName(int actionIndex); + Action * getAction(const std::string & name); }; #endif diff --git a/tape_machine/include/Classifier.hpp b/tape_machine/include/Classifier.hpp index fc4da733310eb0756416460ec76d3b73785ad25d..09a0f218c984ea2c5e8b8918da6d3bc370243108 100644 --- a/tape_machine/include/Classifier.hpp +++ b/tape_machine/include/Classifier.hpp @@ -43,13 +43,14 @@ class Classifier static Type str2type(const std::string & filename); Classifier(const std::string & filename, bool trainMode); - WeightedActions weightActions(Config & config, const std::string & goldAction); + WeightedActions weightActions(Config & config); FeatureModel::FeatureDescription getFeatureDescription(Config & config); std::string getOracleAction(Config & config); int getOracleActionIndex(Config & config); int getScoreOnBatch(MLP::Examples & examples, int start, int end); int trainOnBatch(MLP::Examples & examples, int start, int end); std::string getActionName(int actionIndex); + Action * getAction(const std::string & name); void initClassifier(Config & config); void save(); }; diff --git a/tape_machine/include/Config.hpp b/tape_machine/include/Config.hpp index 352f305646588e7b253a371827f64736d248d63e..55039a267d04109f0c51852bcb6cc78c381886b0 100644 --- a/tape_machine/include/Config.hpp +++ b/tape_machine/include/Config.hpp @@ -21,6 +21,7 @@ class Config std::vector<std::string> & getTapeByInputCol(int col); void readInput(const std::string & filename); void printForDebug(FILE * output); + void printAsOutput(FILE * output); void moveHead(int mvt); bool isFinal(); void reset(); diff --git a/tape_machine/src/ActionSet.cpp b/tape_machine/src/ActionSet.cpp index 8932d8ca03fbbbbd3138a107e171d713e537ae91..e43d5fcdcf9ad453de2b7c19e15fa33c98a509e8 100644 --- a/tape_machine/src/ActionSet.cpp +++ b/tape_machine/src/ActionSet.cpp @@ -54,3 +54,8 @@ std::string ActionSet::getActionName(int actionIndex) return ""; } +Action * ActionSet::getAction(const std::string & name) +{ + return &actions[getActionIndex(name)]; +} + diff --git a/tape_machine/src/Classifier.cpp b/tape_machine/src/Classifier.cpp index 250e38fa2e01fe690dab70ad9df65230e6328580..a187d73e7ca7151fbfbe28f65a1e6f618ac4c476 100644 --- a/tape_machine/src/Classifier.cpp +++ b/tape_machine/src/Classifier.cpp @@ -64,14 +64,12 @@ Classifier::Type Classifier::str2type(const std::string & s) return Type::Prediction; } -Classifier::WeightedActions Classifier::weightActions(Config & config, const std::string & goldAction) +Classifier::WeightedActions Classifier::weightActions(Config & config) { initClassifier(config); - int actionIndex = as->getActionIndex(goldAction); - auto fd = fm->getFeatureDescription(config); - auto scores = mlp->predict(fd, actionIndex); + auto scores = mlp->predict(fd); WeightedActions result; @@ -94,7 +92,7 @@ void Classifier::initClassifier(Config & config) if(!trainMode) { - mlp.reset(new MLP("toto.txt")); + mlp.reset(new MLP(modelFilename)); return; } @@ -166,3 +164,8 @@ void Classifier::save() mlp->save(modelFilename); } +Action * Classifier::getAction(const std::string & name) +{ + return as->getAction(name); +} + diff --git a/tape_machine/src/Config.cpp b/tape_machine/src/Config.cpp index 7f81823b2ade7ae4142b5b84f9c855217602e1c8..0a2cb66d98b0862ceffd7055570716a041d6d825 100644 --- a/tape_machine/src/Config.cpp +++ b/tape_machine/src/Config.cpp @@ -36,6 +36,15 @@ void Config::readInput(const std::string & filename) mcd.getDictOfInputCol(col)->getValue(tape.back()); } } + + // Making all tapes the same size + unsigned int maxTapeSize = 0; + for(auto & tape : tapes) + maxTapeSize = std::max<unsigned int>(maxTapeSize, tape.size()); + + for(auto & tape : tapes) + while(tape.size() < maxTapeSize) + tape.emplace_back(); } void Config::printForDebug(FILE * output) @@ -50,6 +59,13 @@ void Config::printForDebug(FILE * output) } } +void Config::printAsOutput(FILE * output) +{ + for (unsigned int i = 0; i < tapes[0].size(); i++) + for (unsigned int j = 0; j < tapes.size(); j++) + fprintf(output, "%s%s", tapes[j][i].c_str(), j == tapes.size()-1 ? "\n" : "\t"); +} + void Config::moveHead(int mvt) { head += mvt; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 043734e17a56c58d02b11e3399d3ac1007117502..3688ef2d116f10198e439ed3a47e9a8783473567 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -8,4 +8,4 @@ target_link_libraries(test_train ${Boost_PROGRAM_OPTIONS_LIBRARY}) add_executable(test_decode src/test_decode.cpp) target_link_libraries(test_decode tape_machine) target_link_libraries(test_decode decoder) -target_link_libraries(test_train ${Boost_PROGRAM_OPTIONS_LIBRARY}) +target_link_libraries(test_decode ${Boost_PROGRAM_OPTIONS_LIBRARY}) diff --git a/tests/src/test_decode.cpp b/tests/src/test_decode.cpp index 2237a734bfc24d7ff1993e09fce1a100c3f15044..2fc5c54c3e66995c8bea0cb1d7162637b8d6cb2b 100644 --- a/tests/src/test_decode.cpp +++ b/tests/src/test_decode.cpp @@ -1,27 +1,79 @@ #include <cstdio> #include <cstdlib> +#include <boost/program_options.hpp> #include "MCD.hpp" #include "Config.hpp" #include "TapeMachine.hpp" #include "Decoder.hpp" -void printUsageAndExit(char * argv[]) -{ - fprintf(stderr, "USAGE : %s mcd inputFile tm\n", *argv); - exit(1); +namespace po = boost::program_options; + +po::options_description getOptionsDescription() +{ + po::options_description desc("Command-Line Arguments "); + + po::options_description req("Required"); + req.add_options() + ("tm", po::value<std::string>()->required(), + "File describing the Tape Machine to use") + ("mcd", po::value<std::string>()->required(), + "MCD file that describes the input") + ("input,I", po::value<std::string>()->required(), + "Input file formated according to the MCD"); + + po::options_description opt("Optional"); + opt.add_options() + ("help,h", "Produce this help message"); + + desc.add(req).add(opt); + + return desc; +} + +po::variables_map checkOptions(po::options_description & od, int argc, char ** argv) +{ + po::variables_map vm; + + try {po::store(po::parse_command_line(argc, argv, od), vm);} + catch(std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + od.print(std::cerr); + exit(1); + } + + if (vm.count("help")) + { + std::cout << od << "\n"; + exit(0); + } + + try {po::notify(vm);} + catch(std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + od.print(std::cerr); + exit(1); + } + + return vm; } int main(int argc, char * argv[]) { - if (argc != 4) - printUsageAndExit(argv); + auto od = getOptionsDescription(); - MCD mcd(argv[1]); - Config config(mcd); + po::variables_map vm = checkOptions(od, argc, argv); - TapeMachine tapeMachine(argv[3], false); + std::string mcdFilename = vm["mcd"].as<std::string>(); + std::string tmFilename = vm["tm"].as<std::string>(); + std::string inputFilename = vm["input"].as<std::string>(); - config.readInput(argv[2]); + TapeMachine tapeMachine(tmFilename, false); + + MCD mcd(mcdFilename); + Config config(mcd); + config.readInput(inputFilename); Decoder decoder(tapeMachine, mcd, config);