diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ad0a5a87c20ce6e0acdb1b72ba35f112b2ead69..302d7f329d923d1f374f316ac2c5e8f4d27161ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ include_directories(common/include) include_directories(reading_machine/include) include_directories(torch_modules/include) include_directories(trainer/include) +include_directories(decoder/include) include_directories(utf8) add_subdirectory(fmt) @@ -35,4 +36,5 @@ add_subdirectory(dev) add_subdirectory(reading_machine) add_subdirectory(torch_modules) add_subdirectory(trainer) +add_subdirectory(decoder) diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp index d976573f3c3c383acfb1e769a8ec84277ed43861..d87df12ae9634dafe076116266c3859f34b5c11d 100644 --- a/common/include/Dict.hpp +++ b/common/include/Dict.hpp @@ -36,6 +36,7 @@ class Dict void insert(const std::string & element); int getIndexOrInsert(const std::string & element); void setState(State state); + State getState(); void save(std::FILE * destination, Encoding encoding); bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding); void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding); diff --git a/common/src/Dict.cpp b/common/src/Dict.cpp index fcfc942dd73e38140bcc18ae341bea9071a8c7c8..a02edf34487b6598a434946eb3b570659fb046e3 100644 --- a/common/src/Dict.cpp +++ b/common/src/Dict.cpp @@ -79,6 +79,11 @@ void Dict::setState(State state) this->state = state; } +Dict::State Dict::getState() +{ + return state; +} + void Dict::save(std::FILE * destination, Encoding encoding) { fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary"); diff --git a/decoder/CMakeLists.txt b/decoder/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f77044dc910294ee531b4db1698b2044947ea522 --- /dev/null +++ b/decoder/CMakeLists.txt @@ -0,0 +1,5 @@ +FILE(GLOB SOURCES src/*.cpp) + +add_library(decoder STATIC ${SOURCES}) +target_link_libraries(decoder reading_machine) + diff --git a/decoder/include/Decoder.hpp b/decoder/include/Decoder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90f3f205eb73c4b4d4517872c9daf41fa736bebc --- /dev/null +++ b/decoder/include/Decoder.hpp @@ -0,0 +1,27 @@ +#ifndef DECODER__H +#define DECODER__H + +#include <filesystem> +#include "ReadingMachine.hpp" +#include "SubConfig.hpp" + +class Decoder +{ + private : + + ReadingMachine & machine; + std::map<std::string, std::array<float,4>> evaluation; + + public : + + Decoder(ReadingMachine & machine); + void decode(BaseConfig & config, std::size_t beamSize); + void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV); + float getMetricScore(const std::string & metric, std::size_t scoreIndex); + float getPrecision(const std::string & metric); + float getF1Score(const std::string & metric); + float getRecall(const std::string & metric); + float getAlignedAcc(const std::string & metric); +}; + +#endif diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp new file mode 100644 index 0000000000000000000000000000000000000000..86e8044780b021c4d9c93240d9f1f7d178799862 --- /dev/null +++ b/decoder/src/Decoder.cpp @@ -0,0 +1,102 @@ +#include "Decoder.hpp" +#include "SubConfig.hpp" + +Decoder::Decoder(ReadingMachine & machine) : machine(machine) +{ +} + +void Decoder::decode(BaseConfig & config, std::size_t beamSize) +{ + config.setState(machine.getStrategy().getInitialState()); + + while (true) + { + auto dictState = machine.getDict(config.getState()).getState(); + auto context = config.extractContext(5,5,machine.getDict(config.getState())); + machine.getDict(config.getState()).setState(dictState); + + //TODO : check if clone is mandatory + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(); + //TODO : check if NoGradGuard does anything + torch::NoGradGuard guard; + auto prediction = machine.getClassifier()->getNN()(neuralInput); + + int chosenTransition = -1; + + for (unsigned int i = 0; i < prediction.size(0); i++) + if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i)) + chosenTransition = i; + + if (chosenTransition == -1) + util::myThrow("No transition appliable !"); + + auto * transition = machine.getTransitionSet().getTransition(chosenTransition); + + transition->apply(config); + config.addToHistory(transition->getName()); + + auto movement = machine.getStrategy().getMovement(config, transition->getName()); + if (movement == Strategy::endMovement) + break; + + config.setState(movement.first); + if (!config.moveWordIndex(movement.second)) + util::myThrow("Cannot move word index !"); + } +} + +float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) +{ + auto found = evaluation.find(metric); + + if (found == evaluation.end()) + util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : "")); + + return found->second[scoreIndex]; +} + +float Decoder::getPrecision(const std::string & metric) +{ + return getMetricScore(metric, 0); +} + +float Decoder::getRecall(const std::string & metric) +{ + return getMetricScore(metric, 1); +} + +float Decoder::getF1Score(const std::string & metric) +{ + return getMetricScore(metric, 2); +} + +float Decoder::getAlignedAcc(const std::string & metric) +{ + return getMetricScore(metric, 3); +} + +void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV) +{ + evaluation.clear(); + auto predictedTSV = (modelPath/"predicted_dev.tsv").string(); + std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w"); + config.print(predictedTSVFile); + std::fclose(predictedTSVFile); + + std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r"); + + char buffer[1024]; + while (!std::feof(evalFromUD)) + { + if (buffer != std::fgets(buffer, 1024, evalFromUD)) + break; + if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this](auto sm) + { + for (unsigned int i = 0; i < this->evaluation[sm[1]].size(); i++) + this->evaluation[sm[1]][i] = std::stof(sm[1+i]); + })){} + } + + pclose(evalFromUD); +} + diff --git a/dev/CMakeLists.txt b/dev/CMakeLists.txt index 35eee29ea2857b73d66cdf70a2ed2f435700aaba..a83003c045480414c4eb5a7dc5cd508dd02a365b 100644 --- a/dev/CMakeLists.txt +++ b/dev/CMakeLists.txt @@ -5,3 +5,4 @@ target_link_libraries(dev common) target_link_libraries(dev reading_machine) target_link_libraries(dev torch_modules) target_link_libraries(dev trainer) +target_link_libraries(dev decoder) diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index 219bacb258ec7168bef3549278ea9779346c0475..0ae5890b3243e7f4b9c182249d1aea0f33b2b142 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -1,4 +1,4 @@ -#include <cstdio> +#include <filesystem> #include "fmt/core.h" #include "util.hpp" #include "BaseConfig.hpp" @@ -6,6 +6,7 @@ #include "TransitionSet.hpp" #include "ReadingMachine.hpp" #include "Trainer.hpp" +#include "Decoder.hpp" int main(int argc, char * argv[]) { @@ -15,13 +16,16 @@ int main(int argc, char * argv[]) exit(1); } - std::string machineFile = argv[1]; + std::string model = argv[1]; std::string mcdFile = argv[2]; std::string tsvFile = argv[3]; //std::string rawFile = argv[4]; std::string rawFile = ""; - ReadingMachine machine(machineFile); + std::filesystem::path modelPath(model); + auto machinePath = modelPath / "machine.rm"; + + ReadingMachine machine(machinePath.string()); BaseConfig goldConfig(mcdFile, tsvFile, rawFile); SubConfig config(goldConfig); @@ -29,13 +33,17 @@ int main(int argc, char * argv[]) Trainer trainer(machine); trainer.createDataset(config); - int nbEpoch = 5; + Decoder decoder(machine); + + int nbEpoch = 1; for (int i = 0; i < nbEpoch; i++) { float loss = trainer.epoch(); - fmt::print("\r{:80}", " "); - fmt::print("\rEpoch {}/{} loss = {}\n", i+1, nbEpoch, loss); + auto devConfig = goldConfig; + decoder.decode(devConfig, 1); + decoder.evaluate(devConfig, modelPath, tsvFile); + fmt::print(stderr, "\r{:80}\rEpoch {}/{} loss = {} dev = {}\n", " ", i+1, nbEpoch, loss, decoder.getF1Score("UPOS")); } return 0; diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 2b17c7a22358fb87256a0c65e572fd1aa13a0d21..df9551c1a36b36f5dbfce93b115c078de6e69166 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -19,6 +19,7 @@ class TransitionSet std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); Transition * getBestAppliableTransition(const Config & c); std::size_t getTransitionIndex(const Transition * transition) const; + Transition * getTransition(std::size_t index); std::size_t size() const; }; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 6de5017f2583bf8a55bf8ad32166c986195c6c67..7391396dd9b9a99542b4ccfa8fc38da2d2290c43 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -97,7 +97,7 @@ Action Action::addHypothesisRelative(const std::string & colName, Object object, else lineIndex = config.getStack(relativeIndex); - return addHypothesis(colName, lineIndex, "").apply(config, a); + return addHypothesis(colName, lineIndex, hypothesis).apply(config, a); }; auto undo = [colName, object, relativeIndex](Config & config, Action & a) diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 92bd4d0444d1a9f9a5d332fd2048530732058ee2..e7878ac601a788462e06db513a089e15f4552c66 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -80,3 +80,8 @@ std::size_t TransitionSet::getTransitionIndex(const Transition * transition) con return transition - &transitions[0]; } +Transition * TransitionSet::getTransition(std::size_t index) +{ + return &transitions[index]; +} + diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp index 63257082bdd3dd0d51d29063726325fd862c4e1c..19debf8225cdb85af503c07714a839776d0ad480 100644 --- a/torch_modules/src/TestNetwork.cpp +++ b/torch_modules/src/TestNetwork.cpp @@ -29,10 +29,12 @@ torch::Tensor TestNetworkImpl::forward(torch::Tensor input) { // input dim = {batch, sequence, embeddings} auto wordsAsEmb = wordEmbeddings(input); + auto reshaped = wordsAsEmb; // reshaped dim = {sequence, batch, embeddings} - auto reshaped = wordsAsEmb.permute({1,0,2}); + if (reshaped.dim() == 3) + reshaped = wordsAsEmb.permute({1,0,2}); - auto res = torch::softmax(linear(reshaped[focusedIndex]), 1); + auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0); return res; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6279f8e1524cc2d125be3e7a4d726877abea0654..65fb9b98e061d22107b46352375977183dea70c5 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -18,6 +18,7 @@ void Trainer::createDataset(SubConfig & config) if (!transition) util::myThrow("No transition appliable !"); + //TODO : check if clone is mandatory auto context = config.extractContext(5,5,machine.getDict(config.getState())); contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); @@ -82,8 +83,7 @@ float Trainer::epoch() if (nbExamplesUntilPrint <= 0) { nbExamplesUntilPrint = printInterval; - fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); - std::fflush(stdout); + fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar); lossSoFar = 0; } }