From bc2ede62673fecb114423419af163c225283a337 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 12 Feb 2020 15:13:41 +0100 Subject: [PATCH] macaon_decode is working --- common/include/util.hpp | 3 ++ common/src/Dict.cpp | 10 ++--- common/src/util.cpp | 15 +++++++ decoder/src/Decoder.cpp | 3 ++ decoder/src/macaon_decode.cpp | 24 ++++++++--- reading_machine/include/Classifier.hpp | 1 + reading_machine/include/ReadingMachine.hpp | 14 +++++-- reading_machine/src/Classifier.cpp | 5 +++ reading_machine/src/ReadingMachine.cpp | 47 ++++++++++++++++++---- trainer/src/macaon_train.cpp | 11 ++++- 10 files changed, 109 insertions(+), 24 deletions(-) diff --git a/common/include/util.hpp b/common/include/util.hpp index 6b2d077..efe509a 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -19,6 +19,7 @@ #include <array> #include <unordered_map> #include <regex> +#include <filesystem> #include <experimental/source_location> #include <boost/flyweight.hpp> #include "fmt/core.h" @@ -33,6 +34,8 @@ void error(std::string_view message, const std::experimental::source_location & void error(const std::exception & e, const std::experimental::source_location & location = std::experimental::source_location::current()); void myThrow(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); +std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension); + std::string_view getFilenameFromPath(std::string_view s); std::vector<std::string_view> split(std::string_view s, char delimiter); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index a02edf3..74eac88 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -19,11 +19,11 @@ void Dict::readFromFile(const char * filename) std::FILE * file = std::fopen(filename, "r"); if (!file) - util::myThrow(fmt::format("could not open file \'%s\'", filename)); + util::myThrow(fmt::format("could not open file '{}'", filename)); char buffer[1048]; if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1) - util::myThrow(fmt::format("file \'%s\' bad format", filename)); + util::myThrow(fmt::format("file '{}' bad format", filename)); Encoding encoding{Encoding::Ascii}; if (std::string(buffer) == "Ascii") @@ -31,12 +31,12 @@ void Dict::readFromFile(const char * filename) else if (std::string(buffer) == "Binary") encoding = Encoding::Binary; else - util::myThrow(fmt::format("file \'%s\' bad format", filename)); + util::myThrow(fmt::format("file '{}' bad format", filename)); int nbEntries; if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1) - util::myThrow(fmt::format("file \'%s\' bad format", filename)); + util::myThrow(fmt::format("file '{}' bad format", filename)); elementsToIndexes.reserve(nbEntries); @@ -45,7 +45,7 @@ void Dict::readFromFile(const char * filename) for (int i = 0; i < nbEntries; i++) { if (!readEntry(file, &entryIndex, entryString, encoding)) - util::myThrow(fmt::format("file \'%s\' line {} bad format", filename, i)); + util::myThrow(fmt::format("file '{}' line {} bad format", filename, i)); elementsToIndexes[entryString] = entryIndex; } diff --git a/common/src/util.cpp b/common/src/util.cpp index 379056f..13d4122 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -171,3 +171,18 @@ std::string util::strip(const std::string & s) return std::string(s.begin()+first, s.begin()+last+1); } +std::vector<std::filesystem::path> util::findFilesByExtension(std::filesystem::path directory, std::string extension) +{ + std::vector<std::filesystem::path> files; + + for (auto entry : std::filesystem::directory_iterator(directory)) + if (entry.is_regular_file()) + { + auto path = entry.path(); + if (path.extension() == extension) + files.push_back(path); + } + + return files; +} + diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index eea1ef7..664e66e 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -7,6 +7,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) void Decoder::decode(BaseConfig & config, std::size_t beamSize) { + try + { config.setState(machine.getStrategy().getInitialState()); fmt::print(stderr, "\r{:80}\rDecoding dev...", " "); @@ -42,6 +44,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) if (!config.moveWordIndex(movement.second)) util::myThrow("Cannot move word index !"); } + } catch(std::exception & e) {util::myThrow(e.what());} } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index c3c3cbf..8897d40 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -64,19 +64,31 @@ int main(int argc, char * argv[]) auto variables = checkOptions(od, argc, argv); std::filesystem::path modelPath(variables["model"].as<std::string>()); - auto machinePath = modelPath / ReadingMachine::defaultMachineName; + auto machinePath = modelPath / ReadingMachine::defaultMachineFilename; + auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, "")); + auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, "")); auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; auto mcdFile = variables["mcd"].as<std::string>(); - ReadingMachine machine(machinePath.string()); - Decoder decoder(machine); + if (dictPaths.empty()) + util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, ""))); + if (modelPaths.empty()) + util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, ""))); - BaseConfig config(mcdFile, inputTSV, inputTXT); + try + { + ReadingMachine machine(machinePath, modelPaths, dictPaths); + Decoder decoder(machine); + + BaseConfig config(mcdFile, inputTSV, inputTXT); - decoder.decode(config, 1); + decoder.decode(config, 1); - config.print(stdout); + fmt::print(stderr, "\n"); + + config.print(stdout); + } catch(std::exception & e) {util::error(e);} return 0; } diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 0e8b120..35f0611 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -18,6 +18,7 @@ class Classifier Classifier(const std::string & name, const std::string & topology, const std::string & tsFile); TransitionSet & getTransitionSet(); TestNetwork & getNN(); + const std::string & getName() const; }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index 8db2de1..1c08bd8 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -12,9 +12,10 @@ class ReadingMachine { public : - static inline const std::string defaultMachineName = "machine.rm"; - static inline const std::string defaultModelName = "{}.pt"; - static inline const std::string defaultDictName = "{}.dict"; + static inline const std::string defaultMachineFilename = "machine.rm"; + static inline const std::string defaultModelFilename = "{}.pt"; + static inline const std::string defaultDictFilename = "{}.dict"; + static inline const std::string defaultDictName = "_default_"; private : @@ -25,14 +26,19 @@ class ReadingMachine std::unique_ptr<FeatureFunction> featureFunction; std::map<std::string, Dict> dicts; + private : + + void readFromFile(std::filesystem::path path); + public : ReadingMachine(std::filesystem::path path); - ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts); + ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts); TransitionSet & getTransitionSet(); Strategy & getStrategy(); Dict & getDict(const std::string & state); Classifier * getClassifier(); + void save(); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 34c5f60..13e25b4 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -17,3 +17,8 @@ TestNetwork & Classifier::getNN() return nn; } +const std::string & Classifier::getName() const +{ + return name; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index d1d9da6..c9d5f6d 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -3,8 +3,23 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) { - dicts.emplace(std::make_pair("", Dict::State::Open)); + dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); + readFromFile(path); +} + +ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) +{ + readFromFile(path); + + for (auto path : dicts) + this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed}); + + torch::load(classifier->getNN(), models[0]); +} + +void ReadingMachine::readFromFile(std::filesystem::path path) +{ std::FILE * file = std::fopen(path.c_str(), "r"); char buffer[1024]; @@ -49,11 +64,6 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));} } -ReadingMachine::ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts) -{ - -} - TransitionSet & ReadingMachine::getTransitionSet() { return classifier->getTransitionSet(); @@ -68,8 +78,11 @@ Dict & ReadingMachine::getDict(const std::string & state) { auto found = dicts.find(state); - if (found == dicts.end()) - return dicts.at(""); + try + { + if (found == dicts.end()) + return dicts.at(defaultDictName); + } catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));} return found->second; } @@ -79,3 +92,21 @@ Classifier * ReadingMachine::getClassifier() return classifier.get(); } +void ReadingMachine::save() +{ + for (auto & it : dicts) + { + auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first); + std::FILE * file = std::fopen(pathToDict.c_str(), "w"); + if (!file) + util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str())); + + it.second.save(file, Dict::Encoding::Ascii); + + std::fclose(file); + } + + auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName()); + torch::save(classifier->getNN(), pathToClassifier); +} + diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index 896b9ae..d58127d 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -82,13 +82,22 @@ int main(int argc, char * argv[]) Decoder decoder(machine); BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); + float bestDevScore = 0; + for (int i = 0; i < nbEpoch; i++) { float loss = trainer.epoch(); auto devConfig = devGoldConfig; decoder.decode(devConfig, 1); decoder.evaluate(devConfig, modelPath, devTsvFile); - fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS")); + float devScore = decoder.getF1Score("UPOS"); + bool saved = devScore > bestDevScore; + if (saved) + { + bestDevScore = devScore; + machine.save(); + } + fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : ""); } return 0; -- GitLab