diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b655eec25aea029a63b69c9147a13bf8f7a0163..1f45987431919a070295232ecb1a87e339325a12 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ set(CMAKE_VERBOSE_MAKEFILE 0) set(CMAKE_CXX_STANDARD 11) if(NOT CMAKE_BUILD_TYPE) +# set(CMAKE_BUILD_TYPE Debug) set(CMAKE_BUILD_TYPE Release) endif() diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp index 5c61498bae0450509eb9157a76b7e0b14f285c51..8c0d685b50f1b46b0bef385f2c6c6406adf06551 100644 --- a/decoder/include/Decoder.hpp +++ b/decoder/include/Decoder.hpp @@ -2,7 +2,7 @@ #define DECODER__H #include "TapeMachine.hpp" -#include "MCD.hpp" +#include "BD.hpp" #include "Config.hpp" class Decoder @@ -10,12 +10,12 @@ class Decoder private : TapeMachine & tm; - MCD & mcd; + BD & bd; Config & config; public : - Decoder(TapeMachine & tm, MCD & mcd, Config & config); + Decoder(TapeMachine & tm, BD & bd, Config & config); void decode(); }; diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index e299b292384b2cbbb3938cf5d7e31074ef2dead3..8ca9f5508cccf6e646e2665648bf5f9d1e1ea118 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -1,8 +1,8 @@ #include "Decoder.hpp" #include "util.hpp" -Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config) -: tm(tm), mcd(mcd), config(config) +Decoder::Decoder(TapeMachine & tm, BD & bd, Config & config) +: tm(tm), bd(bd), config(config) { } diff --git a/maca_common/include/Dict.hpp b/maca_common/include/Dict.hpp index d9b130fb9859cc55c361448251ed6bce30d880f7..1897833747b63729ef8453a334a0a7c09852cd70 100644 --- a/maca_common/include/Dict.hpp +++ b/maca_common/include/Dict.hpp @@ -61,11 +61,15 @@ class Dict std::vector<float> * addEntry(const std::string & s); Dict(Policy policy, const std::string & filename); + Dict(const std::string & name, int dimension, Mode mode); public : static Dict * getDict(Policy policy, const std::string & filename); + static Dict * getDict(const std::string & 
name); + static void readDicts(const std::string & filename, bool trainMode, const std::string & expPath); ~Dict(); + static void saveDicts(const std::string & directory); void save(); std::vector<float> * getValue(const std::string & s); const std::string * getStr(const std::string & s); diff --git a/maca_common/src/Dict.cpp b/maca_common/src/Dict.cpp index a3906914c8239882477e54471eae97161c042313..3cb38b1aa9d35eec429e074a6a6fee837ca99173 100644 --- a/maca_common/src/Dict.cpp +++ b/maca_common/src/Dict.cpp @@ -42,6 +42,20 @@ const char * Dict::mode2str(Mode mode) return "Embeddings"; } +Dict::Dict(const std::string & name, int dimension, Mode mode) +{ + this->policy = Policy::FromZero; + this->filename = name; + this->name = name; + this->oneHotIndex = 0; + + this->mode = mode; + this->dimension = dimension; + + addEntry(nullValueStr); + addEntry(unknownValueStr); +} + Dict::Dict(Policy policy, const std::string & filename) { auto badFormatAndAbort = [&](std::string errInfo) @@ -94,7 +108,6 @@ Dict::Dict(Policy policy, const std::string & filename) while(fscanf(fd, "%s", b1) == 1) { std::string entry = b1; - //str2vec.emplace(entry, std::vector<float>()); str2vec[entry] = std::vector<float>(); auto & vec = str2vec[entry]; @@ -119,6 +132,15 @@ Dict::Dict(Policy policy, const std::string & filename) } } +void Dict::saveDicts(const std::string & directory) +{ + for (auto & it : str2dict) + { + it.second->filename = directory + it.second->name + ".dict"; + it.second->save(); + } +} + void Dict::save() { // If policy is Final, we didn't change any entry so no need to rewrite the file @@ -264,7 +286,6 @@ void Dict::initEmbeddingFromFasttext(const std::string & s, std::vector<float> & Dict::~Dict() { - save(); } std::vector<float> * Dict::getUnknownValue() @@ -309,11 +330,77 @@ Dict * Dict::getDict(Policy policy, const std::string & filename) if(it != str2dict.end()) return it->second.get(); - str2dict.insert(std::make_pair(filename, std::unique_ptr<Dict>(new 
Dict(policy, filename)))); + Dict * dict = new Dict(policy, filename); + /* The map key is dict->name, which can differ from filename: return the element insert() reports (new or pre-existing) instead of indexing by filename, which would default-insert a null entry. */ + auto inserted = str2dict.insert(std::make_pair(dict->name, std::unique_ptr<Dict>(dict))); - return str2dict[filename].get(); + return inserted.first->second.get(); } +Dict * Dict::getDict(const std::string & name) +{ + auto it = str2dict.find(name); + if(it != str2dict.end()) + return it->second.get(); + + fprintf(stderr, "ERROR (%s) : dictionary \'%s\' does not exists. Aborting.\n", ERRINFO, name.c_str()); + exit(1); + + return nullptr; +} + +void Dict::readDicts(const std::string & filename, bool trainMode, const std::string & expPath) +{ + char buffer[1024]; + char name[1024]; + char modeStr[1024]; + int dim; + + File file(filename, "r"); + FILE * fd = file.getDescriptor(); + /* Field width 1023 keeps an over-long line from overflowing buffer[1024]; name and modeStr are then bounded by buffer's length. */ + while(fscanf(fd, "%1023[^\n]\n", buffer) == 1) + { + if(buffer[0] == '#') + continue; + + if(trainMode) + { + if(sscanf(buffer, "%s %d %s", name, &dim, modeStr) != 3) + { + fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer); + exit(1); + } + + auto it = str2dict.find(name); + if(it != str2dict.end()) + { + fprintf(stderr, "ERROR (%s) : dictionary \'%s\' already exists. Aborting.\n", ERRINFO, name); + exit(1); + } + + str2dict.insert(std::make_pair(name, std::unique_ptr<Dict>(new Dict(name, dim, str2mode(modeStr))))); + } + else + { + if(sscanf(buffer, "%s", name) != 1) + { + fprintf(stderr, "ERROR (%s) : line \'%s\' do not describe a dictionary. Aborting.\n", ERRINFO, buffer); + exit(1); + } + + auto it = str2dict.find(name); + if(it != str2dict.end()) + { + fprintf(stderr, "ERROR (%s) : dictionary \'%s\' already exists. 
Aborting.\n", ERRINFO, name); + exit(1); + } + + getDict(Policy::Final, expPath + name + ".dict"); + } + } +} + int Dict::getDimension() { return dimension; diff --git a/tape_machine/include/MCD.hpp b/tape_machine/include/BD.hpp similarity index 70% rename from tape_machine/include/MCD.hpp rename to tape_machine/include/BD.hpp index f28a4345ed1bf0f7a201af337cb3231193c40c8f..311a5925fd7748bedf1cfb03026a7c20700ae272 100644 --- a/tape_machine/include/MCD.hpp +++ b/tape_machine/include/BD.hpp @@ -1,11 +1,11 @@ -#ifndef MCD__H -#define MCD__H +#ifndef BD__H +#define BD__H #include <map> #include <memory> #include "Dict.hpp" -class MCD +class BD { struct Line { @@ -16,7 +16,7 @@ class MCD bool mustPrint; bool isKnown; - Line(int num, std::string name, std::string dictFilename, std::string dictPolicy, int inputColumn, bool mustPrint, bool isKnown); + Line(int num, std::string name, std::string dictName, std::string dictPolicy, int inputColumn, bool mustPrint, bool isKnown); }; private : @@ -25,21 +25,20 @@ class MCD std::map<int, Line*> num2line; std::map<std::string, Line*> name2line; std::map<int, Line*> col2line; - int nbInputCols; public : - MCD(const std::string & filename); + BD(const std::string & BDfilename, const std::string & MCDfilename, const std::string & expPath); Dict * getDictOfLine(int num); Dict * getDictOfLine(const std::string & name); Dict * getDictOfInputCol(int col); int getLineOfName(const std::string & name); int getLineOfInputCol(int col); - int getNbInputColumns(); int getNbLines(); bool mustPrintLine(int index); const std::string & getNameOfLine(int line); bool lineIsKnown(int line); + bool hasLineOfInputCol(int col); }; #endif diff --git a/tape_machine/include/Classifier.hpp b/tape_machine/include/Classifier.hpp index 219f55018769f0afbf3a1ac6a222971d2c0d94e2..6153cee14b3b63e8c480f9f4fdc30962ca821ac4 100644 --- a/tape_machine/include/Classifier.hpp +++ b/tape_machine/include/Classifier.hpp @@ -31,18 +31,13 @@ class Classifier 
std::unique_ptr<ActionSet> as; std::unique_ptr<MLP> mlp; Oracle * oracle; - std::string modelFilename; - - private : - - void save(const std::string & filename); public : static void printWeightedActions(FILE * output, WeightedActions & wa, int threshold = 5); static Type str2type(const std::string & filename); - Classifier(const std::string & filename, bool trainMode); + Classifier(const std::string & filename, bool trainMode, const std::string & expPath); WeightedActions weightActions(Config & config); FeatureModel::FeatureDescription getFeatureDescription(Config & config); std::string getOracleAction(Config & config); @@ -52,7 +47,7 @@ class Classifier std::string getActionName(int actionIndex); Action * getAction(const std::string & name); void initClassifier(Config & config); - void save(); + void save(const std::string & modelFilename); bool needsTrain(); void printTopology(FILE * output); }; diff --git a/tape_machine/include/Config.hpp b/tape_machine/include/Config.hpp index 5f92535c675fd57d5ad9c4c4f13f3d0270624075..60f46b49ad05206e8fa8ae4616251ba17b63471d 100644 --- a/tape_machine/include/Config.hpp +++ b/tape_machine/include/Config.hpp @@ -2,7 +2,7 @@ #define CONFIG__H #include <vector> -#include "MCD.hpp" +#include "BD.hpp" class Config { @@ -19,17 +19,18 @@ class Config public : - MCD & mcd; + BD & bd; std::vector<Tape> tapes; std::vector<int> stack; int head; std::map< std::string, std::vector<std::string> > actionHistory; std::string * currentStateName; std::string inputFilename; + std::string expPath; public : - Config(MCD & mcd); + Config(BD & bd, const std::string & expPath); Tape & getTape(const std::string & name); Tape & getTapeByInputCol(int col); void readInput(const std::string & filename); diff --git a/tape_machine/include/TapeMachine.hpp b/tape_machine/include/TapeMachine.hpp index e4504698788ac468317da966f76f2e67a0d53ba5..28f2cd68e05073a6b03abb47bba24fd3c9260dea 100644 --- a/tape_machine/include/TapeMachine.hpp +++ 
b/tape_machine/include/TapeMachine.hpp @@ -41,7 +41,7 @@ class TapeMachine public : - TapeMachine(const std::string & filename, bool trainMode); + TapeMachine(const std::string & filename, bool trainMode, const std::string & expPath); State * getCurrentState(); Transition * getTransition(const std::string & action); void takeTransition(Transition * transition); diff --git a/tape_machine/src/BD.cpp b/tape_machine/src/BD.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b8127b560f36f5be387b265c12bedb69b29148c --- /dev/null +++ b/tape_machine/src/BD.cpp @@ -0,0 +1,187 @@ +#include "BD.hpp" +#include "File.hpp" +#include "util.hpp" + +BD::Line::Line(int num, std::string name, std::string dictName, + std::string dictPolicy, int inputColumn, bool mustPrint, bool isKnown) +{ + this->dict = Dict::getDict(Dict::str2policy(dictPolicy), dictName); + this->num = num; + this->name = name; + this->inputColumn = inputColumn; + this->mustPrint = mustPrint; + this->isKnown = isKnown; +} + +BD::BD(const std::string & BDfilename, const std::string & MCDfilename, const std::string & expPath) +{ + char buffer[1024]; + char name[1024]; + char refHyp[1024]; + char dict[1024]; + char policy[1024]; + int mustPrint; + + File mcd(MCDfilename, "r"); + FILE * fd = mcd.getDescriptor(); + std::map<int, std::string> mcdCol2Str; + std::map<std::string, int> mcdStr2Col; + /* MCD pass: field width 1023 keeps an over-long line from overflowing buffer[1024]. */ + while(fscanf(fd, "%1023[^\n]\n", buffer) == 1) + { + if(buffer[0] == '#') + continue; + + int col; + + if(sscanf(buffer, "%d %s", &col, name) != 2) + { + fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid MCD line. Aborting.\n", ERRINFO, buffer); + exit(1); + } + + if(mcdCol2Str.find(col) != mcdCol2Str.end()) + { + fprintf(stderr, "ERROR (%s) : MCD column \'%d\' already exists. Aborting.\n", ERRINFO, col); + exit(1); + } + if(mcdStr2Col.find(name) != mcdStr2Col.end()) + { + fprintf(stderr, "ERROR (%s) : MCD column \'%s\' already exists. 
Aborting.\n", ERRINFO, name); + exit(1); + } + + mcdCol2Str[col] = name; + mcdStr2Col[name] = col; + } + + File bd(BDfilename, "r"); + fd = bd.getDescriptor(); + /* BD pass: same bounded read; the tokens sscanf extracts below cannot exceed buffer's length. */ + while(fscanf(fd, "%1023[^\n]\n", buffer) == 1) + { + if(buffer[0] == '#') + continue; + + if(sscanf(buffer, "%s %s %s %s %d", name, refHyp, dict, policy, &mustPrint) != 5) + { + fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid BD line. Aborting.\n", ERRINFO, buffer); + exit(1); + } + + if(noAccentLower(refHyp) != std::string("ref") && noAccentLower(refHyp) != std::string("hyp")) + { + fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid BD line argument. Aborting.\n", ERRINFO, refHyp); + exit(1); + } + + bool known = noAccentLower(refHyp) == std::string("ref"); + + int inputColumn = mcdStr2Col.find(name) == mcdStr2Col.end() ? -1 : mcdStr2Col[name]; + + lines.emplace_back(new Line(lines.size(), name, dict, policy, inputColumn, mustPrint == 1, known)); + Line * line = lines.back().get(); + num2line.emplace(line->num, line); + name2line.emplace(line->name, line); + col2line.emplace(line->inputColumn, line); + } +} + +Dict * BD::getDictOfLine(int num) +{ + auto it = num2line.find(num); + + if(it == num2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line number %d in BD. Aborting.\n", ERRINFO, num); + exit(1); + } + + return it->second->dict; +} + +Dict * BD::getDictOfLine(const std::string & name) +{ + auto it = name2line.find(name); + + if(it == name2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line \'%s\' in BD. Aborting.\n", ERRINFO, name.c_str()); + exit(1); + } + + return it->second->dict; +} + +Dict * BD::getDictOfInputCol(int col) +{ + auto it = col2line.find(col); + + if(it == col2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line of input column %d in BD. 
Aborting.\n", ERRINFO, col); + exit(1); + } + + return it->second->dict; +} + +int BD::getLineOfName(const std::string & name) +{ + auto it = name2line.find(name); + + if(it == name2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line %s in BD. Aborting.\n", ERRINFO, name.c_str()); + exit(1); + } + + return it->second->num; +} + +int BD::getLineOfInputCol(int col) +{ + auto it = col2line.find(col); + + if(it == col2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line in BD corresponding to input col %d. Aborting.\n", ERRINFO, col); + exit(1); + } + + return it->second->num; +} + +bool BD::hasLineOfInputCol(int col) +{ + return col2line.find(col) != col2line.end(); +} + +int BD::getNbLines() +{ + return lines.size(); +} + +bool BD::mustPrintLine(int index) +{ + auto it = num2line.find(index); + + if(it == num2line.end()) + { + fprintf(stderr, "ERROR (%s) : requesting line number %d in BD. Aborting.\n", ERRINFO, index); + exit(1); + } + + return it->second->mustPrint; +} + +const std::string & BD::getNameOfLine(int line) +{ + return lines[line]->name; +} + +bool BD::lineIsKnown(int line) +{ + return lines[line]->isKnown; +} + diff --git a/tape_machine/src/Classifier.cpp b/tape_machine/src/Classifier.cpp index b209eab5224ace7ebdebb706c334d28986426696..d103e5274f5fa4f3980b1331312627d3f1e58809 100644 --- a/tape_machine/src/Classifier.cpp +++ b/tape_machine/src/Classifier.cpp @@ -2,7 +2,7 @@ #include "File.hpp" #include "util.hpp" -Classifier::Classifier(const std::string & filename, bool trainMode) +Classifier::Classifier(const std::string & filename, bool trainMode, const std::string & expPath) { this->trainMode = trainMode; @@ -13,7 +13,7 @@ Classifier::Classifier(const std::string & filename, bool trainMode) exit(1); }; - File file(filename, "r"); + File file(expPath + filename, "r"); FILE * fd = file.getDescriptor(); char buffer[1024]; @@ -40,7 +40,7 @@ Classifier::Classifier(const std::string & filename, bool trainMode) if(fscanf(fd, "Oracle 
Filename : %s\n", buffer2) != 1) badFormatAndAbort(ERRINFO); - oracle = Oracle::getOracle(buffer, buffer2); + oracle = Oracle::getOracle(buffer, expPath + std::string("/") + buffer2); } else oracle = Oracle::getOracle(buffer); @@ -55,17 +55,12 @@ Classifier::Classifier(const std::string & filename, bool trainMode) if(fscanf(fd, "Feature Model : %s\n", buffer) != 1) badFormatAndAbort(ERRINFO); - fm.reset(new FeatureModel(buffer)); + fm.reset(new FeatureModel(expPath + buffer)); if(fscanf(fd, "Action Set : %s\n", buffer) != 1) badFormatAndAbort(ERRINFO); - as.reset(new ActionSet(buffer, false)); - - if(fscanf(fd, "Model : %s\n", buffer) != 1) - badFormatAndAbort(ERRINFO); - - modelFilename = buffer; + as.reset(new ActionSet(expPath + buffer, false)); } Classifier::Type Classifier::str2type(const std::string & s) @@ -120,7 +115,7 @@ void Classifier::initClassifier(Config & config) if(!trainMode) { - mlp.reset(new MLP(modelFilename)); + mlp.reset(new MLP(config.expPath + name + ".model")); return; } @@ -214,11 +209,6 @@ void Classifier::save(const std::string & filename) mlp->save(filename); } -void Classifier::save() -{ - mlp->save(modelFilename); -} - Action * Classifier::getAction(const std::string & name) { return as->getAction(name); diff --git a/tape_machine/src/Config.cpp b/tape_machine/src/Config.cpp index 93a11dd72fe8b9e5a2870d96b3ec3577d03e2532..d53a248b5e179190d4ca06b3f361718b7c1c300f 100644 --- a/tape_machine/src/Config.cpp +++ b/tape_machine/src/Config.cpp @@ -1,46 +1,47 @@ #include "Config.hpp" #include "File.hpp" -Config::Config(MCD & mcd) : mcd(mcd), tapes(mcd.getNbLines()) +Config::Config(BD & bd, const std::string & expPath) : bd(bd), tapes(bd.getNbLines()) { + this->expPath = expPath; this->currentStateName = nullptr; head = 0; for(unsigned int i = 0; i < tapes.size(); i++) { - tapes[i].name = mcd.getNameOfLine(i); - tapes[i].isKnown = mcd.lineIsKnown(i); + tapes[i].name = bd.getNameOfLine(i); + tapes[i].isKnown = bd.lineIsKnown(i); } } Config::Tape 
& Config::getTape(const std::string & name) { - return tapes[mcd.getLineOfName(name)]; + return tapes[bd.getLineOfName(name)]; } Config::Tape & Config::getTapeByInputCol(int col) { - return tapes[mcd.getLineOfInputCol(col)]; + return tapes[bd.getLineOfInputCol(col)]; } void Config::readInput(const std::string & filename) { this->inputFilename = filename; File file(filename, "r"); - int nbInputCol = mcd.getNbInputColumns(); + FILE * fd = file.getDescriptor(); - while(!file.isFinished()) + char buffer[10000]; + std::vector<std::string> cols; + /* Field width 9999 keeps an over-long input line from overflowing buffer[10000]. */ + while(fscanf(fd, "%9999[^\n]\n", buffer) == 1) { - for (int col = 0; col < nbInputCol; col++) - { - auto & tape = getTapeByInputCol(col); - tape.ref.emplace_back(); - file.readUntil(isNotSeparator); - file.readUntil(tape.ref.back(), isSeparator); - file.readUntil(isNotSeparator); + cols = split(buffer); + for(unsigned int i = 0; i < cols.size(); i++) + if(bd.hasLineOfInputCol(i)) + { + auto & tape = getTapeByInputCol(i); - // Not necessary, it just add the value to the dictionary - mcd.getDictOfInputCol(col)->getValue(tape.ref.back()); - } + tape.ref.emplace_back(cols[i]); + } } // Making all tapes the same size @@ -129,13 +130,13 @@ void Config::printAsOutput(FILE * output) { unsigned int lastToPrint = 0; for (unsigned int j = 0; j < tapes.size(); j++) - if(mcd.mustPrintLine(j)) + if(bd.mustPrintLine(j)) lastToPrint = j; for (unsigned int i = 0; i < tapes[0].hyp.size(); i++) { for (unsigned int j = 0; j < tapes.size(); j++) - if(mcd.mustPrintLine(j)) + if(bd.mustPrintLine(j)) fprintf(output, "%s%s", tapes[j][i].empty() ? "0" : tapes[j][i].c_str(), j == lastToPrint ? 
"\n" : "\t"); } } diff --git a/tape_machine/src/FeatureBank.cpp b/tape_machine/src/FeatureBank.cpp index dd218daf9774b28c244a0a22651d1f8395da1bd6..3646100d471905daf762b5c2bd166af38d50f4d1 100644 --- a/tape_machine/src/FeatureBank.cpp +++ b/tape_machine/src/FeatureBank.cpp @@ -107,7 +107,7 @@ std::function<FeatureModel::FeatureValue(Config &)> FeatureBank::str2func(const FeatureModel::FeatureValue FeatureBank::actionHistory(Config & config, int index, const std::string & featName) { - Dict * dict = config.mcd.getDictOfLine("actions"); + Dict * dict = Dict::getDict("actions"); auto policy = dictPolicy2FeaturePolicy(dict->policy); auto & history = config.actionHistory[*config.currentStateName]; @@ -122,7 +122,7 @@ FeatureModel::FeatureValue FeatureBank::ldep(Config & config, int index, const s auto & tape = config.getTape(tapeName); auto & govs = config.getTape("GOV"); auto & eos = config.getTape("EOS"); - Dict * dict = config.mcd.getDictOfLine(tapeName); + Dict * dict = config.bd.getDictOfLine(tapeName); auto policy = dictPolicy2FeaturePolicy(dict->policy); if(object == "s") @@ -161,7 +161,7 @@ FeatureModel::FeatureValue FeatureBank::rdep(Config & config, int index, const s auto & tape = config.getTape(tapeName); auto & govs = config.getTape("GOV"); auto & eos = config.getTape("EOS"); - Dict * dict = config.mcd.getDictOfLine(tapeName); + Dict * dict = config.bd.getDictOfLine(tapeName); auto policy = dictPolicy2FeaturePolicy(dict->policy); if(object == "s") @@ -198,7 +198,7 @@ FeatureModel::FeatureValue FeatureBank::rdep(Config & config, int index, const s FeatureModel::FeatureValue FeatureBank::simpleBufferAccess(Config & config, int relativeIndex, const std::string & tapeName, const std::string & featName) { auto & tape = config.getTape(tapeName); - Dict * dict = config.mcd.getDictOfLine(tapeName); + Dict * dict = config.bd.getDictOfLine(tapeName); auto policy = dictPolicy2FeaturePolicy(dict->policy); int index = config.head + relativeIndex; @@ -215,7 +215,7 @@ 
FeatureModel::FeatureValue FeatureBank::simpleBufferAccess(Config & config, int FeatureModel::FeatureValue FeatureBank::simpleStackAccess(Config & config, int relativeIndex, const std::string & tapeName, const std::string & featName) { auto & tape = config.getTape(tapeName); - Dict * dict = config.mcd.getDictOfLine(tapeName); + Dict * dict = config.bd.getDictOfLine(tapeName); auto policy = dictPolicy2FeaturePolicy(dict->policy); if(relativeIndex < 0 || relativeIndex >= (int)config.stack.size()) @@ -232,9 +232,9 @@ FeatureModel::FeatureValue FeatureBank::simpleStackAccess(Config & config, int r return {dict, featName, &tape[index], dict->getValue(tape[index]), policy}; } -FeatureModel::FeatureValue FeatureBank::getUppercase(Config & config, const FeatureModel::FeatureValue & fv) +FeatureModel::FeatureValue FeatureBank::getUppercase(Config &, const FeatureModel::FeatureValue & fv) { - Dict * dict = config.mcd.getDictOfLine("boolean"); + Dict * dict = Dict::getDict("bool"); auto policy = dictPolicy2FeaturePolicy(dict->policy); bool firstLetterUppercase = isUpper((*fv.value)[0]); @@ -246,9 +246,9 @@ FeatureModel::FeatureValue FeatureBank::getUppercase(Config & config, const Feat return {dict, fv.name, str, dict->getValue(*str), policy}; } -FeatureModel::FeatureValue FeatureBank::getLength(Config & config, const FeatureModel::FeatureValue & fv) +FeatureModel::FeatureValue FeatureBank::getLength(Config &, const FeatureModel::FeatureValue & fv) { - Dict * dict = config.mcd.getDictOfLine("integer"); + Dict * dict = Dict::getDict("int"); auto policy = dictPolicy2FeaturePolicy(dict->policy); int len = lengthPrinted(*fv.value); @@ -264,9 +264,9 @@ FeatureModel::FeatureValue FeatureBank::getLength(Config & config, const Feature return {dict, fv.name, str, dict->getValue(*str), policy}; } -FeatureModel::FeatureValue FeatureBank::getLetters(Config & config, const FeatureModel::FeatureValue & fv, int from, int to) +FeatureModel::FeatureValue FeatureBank::getLetters(Config &, 
const FeatureModel::FeatureValue & fv, int from, int to) { - Dict * dict = config.mcd.getDictOfLine("letters"); + Dict * dict = Dict::getDict("letters"); auto policy = dictPolicy2FeaturePolicy(dict->policy); if(*fv.value == Dict::nullValueStr) diff --git a/tape_machine/src/MCD.cpp b/tape_machine/src/MCD.cpp deleted file mode 100644 index 9de2f413e722f02b02de007fa046ee32650b875d..0000000000000000000000000000000000000000 --- a/tape_machine/src/MCD.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "MCD.hpp" -#include "File.hpp" -#include "util.hpp" - -MCD::Line::Line(int num, std::string name, std::string dictFilename, - std::string dictPolicy, int inputColumn, bool mustPrint, bool isKnown) -{ - this->dict = Dict::getDict(Dict::str2policy(dictPolicy), dictFilename); - this->num = num; - this->name = name; - this->inputColumn = inputColumn; - this->mustPrint = mustPrint; - this->isKnown = isKnown; -} - -MCD::MCD(const std::string & filename) -{ - File file(filename, "r"); - FILE * fd = file.getDescriptor(); - - int num; - char b0[1024]; - char b1[1024]; - char b2[1024]; - char b3[1024]; - int inputColumn; - int mustPrint; - - nbInputCols = 0; - - while(fscanf(fd, "%d %s %s %s %s %d %d\n", &num, b1, b0, b2, b3, &inputColumn, &mustPrint) == 7) - { - if(inputColumn >= 0) - nbInputCols++; - - if(noAccentLower(b0) != std::string("known") && noAccentLower(b0) != std::string("unknown")) - { - fprintf(stderr, "ERROR (%s) : \'%s\' is not a valid MCD line argument. Aborting.\n", ERRINFO, b0); - exit(1); - } - - bool known = noAccentLower(b0) == std::string("known"); - - lines.emplace_back(new Line(num, b1, b2, b3, inputColumn, mustPrint == 1, known)); - Line * line = lines.back().get(); - num2line.emplace(line->num, line); - name2line.emplace(line->name, line); - col2line.emplace(line->inputColumn, line); - } -} - -Dict * MCD::getDictOfLine(int num) -{ - auto it = num2line.find(num); - - if(it == num2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line number %d in MCD. 
Aborting.\n", ERRINFO, num); - exit(1); - } - - return it->second->dict; -} - -Dict * MCD::getDictOfLine(const std::string & name) -{ - auto it = name2line.find(name); - - if(it == name2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line \'%s\' in MCD. Aborting.\n", ERRINFO, name.c_str()); - exit(1); - } - - return it->second->dict; -} - -Dict * MCD::getDictOfInputCol(int col) -{ - auto it = col2line.find(col); - - if(it == col2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line of input column %d in MCD. Aborting.\n", ERRINFO, col); - exit(1); - } - - return it->second->dict; -} - -int MCD::getLineOfName(const std::string & name) -{ - auto it = name2line.find(name); - - if(it == name2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line %s in MCD. Aborting.\n", ERRINFO, name.c_str()); - exit(1); - } - - return it->second->num; -} - -int MCD::getLineOfInputCol(int col) -{ - auto it = col2line.find(col); - - if(it == col2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line in MCD corresponding to input col %d. Aborting.\n", ERRINFO, col); - exit(1); - } - - return it->second->num; -} - -int MCD::getNbInputColumns() -{ - return nbInputCols; -} - -int MCD::getNbLines() -{ - return lines.size(); -} - -bool MCD::mustPrintLine(int index) -{ - auto it = num2line.find(index); - - if(it == num2line.end()) - { - fprintf(stderr, "ERROR (%s) : requesting line number %d in MCD. 
Aborting.\n", ERRINFO, index); - exit(1); - } - - return it->second->mustPrint; -} - -const std::string & MCD::getNameOfLine(int line) -{ - return lines[line]->name; -} - -bool MCD::lineIsKnown(int line) -{ - return lines[line]->isKnown; -} - diff --git a/tape_machine/src/TapeMachine.cpp b/tape_machine/src/TapeMachine.cpp index 1356183b4a5c30b7060aab873e446010cac11525..0639c6ba89a71155d8cfd68faddac30e28830886 100644 --- a/tape_machine/src/TapeMachine.cpp +++ b/tape_machine/src/TapeMachine.cpp @@ -3,7 +3,7 @@ #include "util.hpp" #include <cstring> -TapeMachine::TapeMachine(const std::string & filename, bool trainMode) +TapeMachine::TapeMachine(const std::string & filename, bool trainMode, const std::string & expPath) { auto badFormatAndAbort = [&filename](const std::string & errInfo) { @@ -27,6 +27,12 @@ TapeMachine::TapeMachine(const std::string & filename, bool trainMode) name = buffer; + // Reading dicts + if(fscanf(fd, "Dicts : %[^\n]\n", buffer) != 1) + badFormatAndAbort(ERRINFO); + + Dict::readDicts(expPath + buffer, trainMode, expPath); + // Reading %CLASSIFIERS if(fscanf(fd, "%%%s\n", buffer) != 1 || buffer != std::string("CLASSIFIERS")) badFormatAndAbort(ERRINFO); @@ -37,7 +43,7 @@ TapeMachine::TapeMachine(const std::string & filename, bool trainMode) if(fscanf(fd, "%s %s\n", buffer, buffer2) != 2) badFormatAndAbort(ERRINFO); - str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2, trainMode))); + str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2, trainMode, expPath))); classifiers.emplace_back(str2classifier[buffer].get()); } diff --git a/tests/src/test_decode.cpp b/tests/src/test_decode.cpp index 2fc5c54c3e66995c8bea0cb1d7162637b8d6cb2b..2d1387792fb859de8a94e536153a2f5ba962914d 100644 --- a/tests/src/test_decode.cpp +++ b/tests/src/test_decode.cpp @@ -1,7 +1,7 @@ #include <cstdio> #include <cstdlib> #include <boost/program_options.hpp> -#include "MCD.hpp" +#include "BD.hpp" #include "Config.hpp" 
#include "TapeMachine.hpp" #include "Decoder.hpp" @@ -14,16 +14,23 @@ po::options_description getOptionsDescription() po::options_description req("Required"); req.add_options() + ("expName", po::value<std::string>()->required(), + "Name of this experiment") ("tm", po::value<std::string>()->required(), "File describing the Tape Machine to use") + ("bd", po::value<std::string>()->required(), + "BD file that describes the multi-tapes buffer") ("mcd", po::value<std::string>()->required(), "MCD file that describes the input") ("input,I", po::value<std::string>()->required(), - "Input file formated according to the MCD"); + "Input file formated according to the mcd"); po::options_description opt("Optional"); opt.add_options() - ("help,h", "Produce this help message"); + ("help,h", "Produce this help message") + ("lang", po::value<std::string>()->default_value("fr"), + "Language you are working with"); + desc.add(req).add(opt); @@ -65,17 +72,28 @@ int main(int argc, char * argv[]) po::variables_map vm = checkOptions(od, argc, argv); - std::string mcdFilename = vm["mcd"].as<std::string>(); std::string tmFilename = vm["tm"].as<std::string>(); + std::string bdFilename = vm["bd"].as<std::string>(); + std::string mcdFilename = vm["mcd"].as<std::string>(); std::string inputFilename = vm["input"].as<std::string>(); + std::string lang = vm["lang"].as<std::string>(); + std::string expName = vm["expName"].as<std::string>(); + + const char * MACAON_DIR = std::getenv("MACAON_DIR"); + std::string slash = "/"; + std::string expPath = MACAON_DIR + slash + lang + slash + "bin/" + expName + slash; + + tmFilename = expPath + tmFilename; + bdFilename = expPath + bdFilename; + mcdFilename = expPath + mcdFilename; - TapeMachine tapeMachine(tmFilename, false); + TapeMachine tapeMachine(tmFilename, false, expPath); - MCD mcd(mcdFilename); - Config config(mcd); + BD bd(bdFilename, mcdFilename, expPath); + Config config(bd, expPath); config.readInput(inputFilename); - Decoder decoder(tapeMachine, 
mcd, config); + Decoder decoder(tapeMachine, bd, config); decoder.decode(); diff --git a/tests/src/test_train.cpp b/tests/src/test_train.cpp index ee05c19db4499b0016ce7daea21107bb855f2f29..9f367e4ba8b8e0e7028a156aafabd4f21299709a 100644 --- a/tests/src/test_train.cpp +++ b/tests/src/test_train.cpp @@ -1,7 +1,7 @@ #include <cstdio> #include <cstdlib> #include <boost/program_options.hpp> -#include "MCD.hpp" +#include "BD.hpp" #include "Config.hpp" #include "TapeMachine.hpp" #include "Trainer.hpp" @@ -14,8 +14,12 @@ po::options_description getOptionsDescription() po::options_description req("Required"); req.add_options() + ("expName", po::value<std::string>()->required(), + "Name of this experiment") ("tm", po::value<std::string>()->required(), "File describing the Tape Machine we will train") + ("bd", po::value<std::string>()->required(), + "BD file that describes the multi-tapes buffer") ("mcd", po::value<std::string>()->required(), "MCD file that describes the input") ("train,T", po::value<std::string>()->required(), @@ -24,10 +28,10 @@ po::options_description getOptionsDescription() po::options_description opt("Optional"); opt.add_options() ("help,h", "Produce this help message") - ("devmcd", po::value<std::string>()->default_value(""), - "MCD file that describes the input") ("dev", po::value<std::string>()->default_value(""), "Development corpus formated according to the MCD") + ("lang", po::value<std::string>()->default_value("fr"), + "Language you are working with") ("nbiter,n", po::value<int>()->default_value(5), "Number of training epochs (iterations)") ("batchsize,b", po::value<int>()->default_value(256), @@ -75,38 +79,51 @@ int main(int argc, char * argv[]) po::variables_map vm = checkOptions(od, argc, argv); - std::string trainMCDfilename = vm["mcd"].as<std::string>(); - std::string devMCDfilename = vm["devmcd"].as<std::string>(); + std::string BDfilename = vm["bd"].as<std::string>(); + std::string MCDfilename = vm["mcd"].as<std::string>(); std::string 
tmFilename = vm["tm"].as<std::string>(); std::string trainFilename = vm["train"].as<std::string>(); std::string devFilename = vm["dev"].as<std::string>(); + std::string expName = vm["expName"].as<std::string>(); + std::string lang = vm["lang"].as<std::string>(); int nbIter = vm["nbiter"].as<int>(); int batchSize = vm["batchsize"].as<int>(); bool mustShuffle = vm["shuffle"].as<bool>(); - TapeMachine tapeMachine(tmFilename, true); + const char * MACAON_DIR = std::getenv("MACAON_DIR"); + std::string slash = "/"; + std::string expPath = MACAON_DIR + slash + lang + slash + "bin/" + expName + slash; + + BDfilename = expPath + BDfilename; + MCDfilename = expPath + MCDfilename; + tmFilename = expPath + tmFilename; + trainFilename = expPath + trainFilename; + devFilename = expPath + devFilename; + + TapeMachine tapeMachine(tmFilename, true, expPath); - MCD trainMcd(trainMCDfilename); - Config trainConfig(trainMcd); + BD trainBD(BDfilename, MCDfilename, expPath); + Config trainConfig(trainBD, expPath); trainConfig.readInput(trainFilename); - std::unique_ptr<MCD> devMcd; + std::unique_ptr<BD> devBD; std::unique_ptr<Config> devConfig; std::unique_ptr<Trainer> trainer; - if(devFilename.empty() || devMCDfilename.empty()) + if(devFilename.empty()) { - trainer.reset(new Trainer(tapeMachine, trainMcd, trainConfig)); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig)); } else { - devMcd.reset(new MCD(devMCDfilename)); - devConfig.reset(new Config(*devMcd.get())); + devBD.reset(new BD(BDfilename, MCDfilename, expPath)); + devConfig.reset(new Config(*devBD.get(), expPath)); devConfig->readInput(devFilename); - trainer.reset(new Trainer(tapeMachine, trainMcd, trainConfig, devMcd.get(), devConfig.get())); + trainer.reset(new Trainer(tapeMachine, trainBD, trainConfig, devBD.get(), devConfig.get())); } + trainer->expPath = expPath; trainer->train(nbIter, batchSize, mustShuffle); return 0; diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 
bf1b1c757a78419ee35470e1e397262bdcefc19d..f3d202aa589dbd4ef5187b3f741406db40455c33 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -2,18 +2,22 @@ #define TRAINER__H #include "TapeMachine.hpp" -#include "MCD.hpp" +#include "BD.hpp" #include "Config.hpp" class Trainer { + public : + + std::string expPath; + private : TapeMachine & tm; - MCD & trainMcd; + BD & trainBD; Config & trainConfig; - MCD * devMcd; + BD * devBD; Config * devConfig; public : @@ -44,8 +48,8 @@ class Trainer public : - Trainer(TapeMachine & tm, MCD & mcd, Config & config); - Trainer(TapeMachine & tm, MCD & mcd, Config & config, MCD * devMcd, Config * devConfig); + Trainer(TapeMachine & tm, BD & bd, Config & config); + Trainer(TapeMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig); void train(int nbIter, int batchSize, bool mustShuffle); }; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 088c306675cb3d49c5dfaa73e7934747599d80f0..91bffcddb6b5408b36e511cd386ea25a3430d1ef 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,14 +1,14 @@ #include "Trainer.hpp" #include "util.hpp" -Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config) -: tm(tm), trainMcd(mcd), trainConfig(config) +Trainer::Trainer(TapeMachine & tm, BD & bd, Config & config) +: tm(tm), trainBD(bd), trainConfig(config) { - this->devMcd = nullptr; + this->devBD = nullptr; this->devConfig = nullptr; } -Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config, MCD * devMcd, Config * devConfig) : tm(tm), trainMcd(mcd), trainConfig(config), devMcd(devMcd), devConfig(devConfig) +Trainer::Trainer(TapeMachine & tm, BD & bd, Config & config, BD * devBD, Config * devConfig) : tm(tm), trainBD(bd), trainConfig(config), devBD(devBD), devConfig(devConfig) { } @@ -137,7 +137,7 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) getExamplesByClassifier(trainExamples, trainConfig); - if(devMcd && devConfig) + if(devBD && 
devConfig) getExamplesByClassifier(devExamples, *devConfig); auto & classifiers = tm.getClassifiers(); @@ -175,7 +175,10 @@ void Trainer::trainBatched(int nbIter, int batchSize, bool mustShuffle) for(Classifier * cla : classifiers) if(cla->needsTrain()) if(bestIter[cla->name] == i) - cla->save(); + { + cla->save(expPath + cla->name + ".model"); + Dict::saveDicts(expPath); + } } }