From 78a246d70252dec64580e1b3e0df8511e0b7fbb2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 13 Jan 2020 15:13:02 +0100 Subject: [PATCH] Working skeleton of oracle decoding --- dev/src/dev.cpp | 46 +++++++++++++++++----- reading_machine/include/Classifier.hpp | 1 + reading_machine/include/ReadingMachine.hpp | 2 + reading_machine/include/Strategy.hpp | 2 + reading_machine/include/Transition.hpp | 1 + reading_machine/include/TransitionSet.hpp | 4 +- reading_machine/src/Classifier.cpp | 5 +++ reading_machine/src/Config.cpp | 9 +++-- reading_machine/src/ReadingMachine.cpp | 10 +++++ reading_machine/src/Strategy.cpp | 7 ++++ reading_machine/src/Transition.cpp | 5 +++ reading_machine/src/TransitionSet.cpp | 36 ++++++++++++++--- 12 files changed, 107 insertions(+), 21 deletions(-) diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index d33d370..c538b98 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -8,23 +8,49 @@ int main(int argc, char * argv[]) { - /* - BaseConfig goldConfig(argv[3], argv[1], argv[2]); + if (argc != 5) + { + fmt::print(stderr, "needs 4 arguments.\n"); + exit(1); + } + + std::string machineFile = argv[1]; + std::string mcdFile = argv[2]; + std::string tsvFile = argv[3]; + //std::string rawFile = argv[4]; + std::string rawFile = ""; + + ReadingMachine machine(machineFile); + BaseConfig goldConfig(mcdFile, tsvFile, rawFile); SubConfig config(goldConfig); - auto other = config; - while (config.moveWordIndex(1)) + config.setState(machine.getStrategy().getInitialState()); + + config.printForDebug(stderr); + + while (true) { + auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); + if (!transition) + util::myThrow("No transition appliable !"); + + fmt::print(stderr, "Transition : {}\n", transition->getName()); + transition->apply(config); + + auto movement = machine.getStrategy().getMovement(config, transition->getName()); + if (movement == Strategy::endMovement) + break; + + config.setState(movement.first); + if (!config.moveWordIndex(movement.second)) + util::myThrow("Cannot move word index !"); + if (config.needsUpdate()) config.update(); - } - - fmt::print(stderr, "ok\n"); - std::scanf("%*c"); - */ - ReadingMachine machine(argv[1]); + config.printForDebug(stderr); + } return 0; } diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index ce61d5a..5d38ae8 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -16,6 +16,7 @@ class Classifier public : Classifier(const std::string & name, const std::string & topology, const std::string & tsFile); + TransitionSet & getTransitionSet(); }; #endif diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp index dc4be73..ab5d794 100644 --- a/reading_machine/include/ReadingMachine.hpp +++ b/reading_machine/include/ReadingMachine.hpp @@ -16,6 +16,8 @@ class ReadingMachine public : ReadingMachine(const std::string & filename); + TransitionSet & getTransitionSet(); + Strategy & getStrategy(); }; #endif diff --git a/reading_machine/include/Strategy.hpp b/reading_machine/include/Strategy.hpp index e1bc580..10d9628 100644 --- a/reading_machine/include/Strategy.hpp +++ b/reading_machine/include/Strategy.hpp @@ -21,6 +21,7 @@ class Strategy std::map<std::pair<std::string, std::string>, std::string> edges; std::map<std::string, bool> isDone; std::vector<std::string> defaultCycle; + std::string initialState{"UNDEFINED"}; private : @@ -31,6 +32,7 @@ class Strategy Strategy(const std::vector<std::string_view> & lines); std::pair<std::string, int> getMovement(const Config & c, const std::string & transition); + const std::string getInitialState() const; }; #endif diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index ecd90df..cbdf21f 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -23,6 +23,7 @@ class Transition void apply(Config & config); bool appliable(const Config & config) const; int getCost(const Config & config) const; + const std::string & getName() const; }; #endif diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 04c1e90..188ce70 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -11,13 +11,13 @@ class TransitionSet private : std::vector<Transition> transitions; - std::unordered_map<std::string, std::size_t> name2index; std::optional<std::size_t> defaultAction; public : TransitionSet(const std::string & filename); - std::vector<std::pair<Transition &, int>> getAppliableTransitionsCosts(const Config & c); + std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c); + Transition * getBestAppliableTransition(const Config & c); }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 47100c7..d446be2 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -7,3 +7,8 @@ Classifier::Classifier(const std::string & name, const std::string & topology, c this->nn = MLP(topology); } +TransitionSet & Classifier::getTransitionSet() +{ + return *transitionSet; +} + diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 1a29fff..f4d1001 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -80,7 +80,9 @@ void Config::printForDebug(FILE * dest) const { static constexpr int windowSize = 5; static constexpr int lettersWindowSize = 40; - static constexpr int maxWordLength = 10; + static constexpr int maxWordLength = 7; + + fmt::print(dest, "\n"); int firstLineToPrint = wordIndex; int lastLineToPrint = wordIndex; @@ -138,7 +140,8 @@ void Config::printForDebug(FILE * dest) const fmt::print(dest, "{}\n", longLine); for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++) fmt::print(dest, "{}", getLetter(index)); - fmt::print(dest, "\n{}\n", longLine); + if (rawInput.size()) + fmt::print(dest, "\n{}\n", longLine); fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr); fmt::print(dest, "{}\n", longLine); @@ -151,8 +154,6 @@ void Config::printForDebug(FILE * dest) const if (toPrint[line].back() == EOSSymbol1) fmt::print(dest, "\n"); } - - fmt::print(dest, "\n"); } Config::String & Config::getLastNotEmpty(int colIndex, int lineIndex) diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index cc484eb..f2d49be 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -34,3 +34,13 @@ ReadingMachine::ReadingMachine(const std::string & filename) } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));} } +TransitionSet & ReadingMachine::getTransitionSet() +{ + return classifier->getTransitionSet(); +} + +Strategy & ReadingMachine::getStrategy() +{ + return *strategy; +} + diff --git a/reading_machine/src/Strategy.cpp b/reading_machine/src/Strategy.cpp index 3b49cd4..ca045f1 100644 --- a/reading_machine/src/Strategy.cpp +++ b/reading_machine/src/Strategy.cpp @@ -16,6 +16,8 @@ Strategy::Strategy(const std::vector<std::string_view> & lines) { key = std::pair<std::string,std::string>(splited[0], ""); value = splited[1]; + if (defaultCycle.empty()) + initialState = splited[0]; defaultCycle.emplace_back(value); } else if (splited.size() == 3) @@ -100,3 +102,8 @@ std::pair<std::string, int> Strategy::getMovementIncremental(const Config & c, c return {defaultCycle.back(), 1}; } +const std::string Strategy::getInitialState() const +{ + return initialState; +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index d8b9c8a..46147d4 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -61,3 +61,8 @@ void Transition::initWrite(std::string colName, std::string object, std::string }; } +const std::string & Transition::getName() const +{ + return name; +} + diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 467ad42..a1a1f89 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -1,4 +1,5 @@ #include "TransitionSet.hpp" +#include <limits> TransitionSet::TransitionSet(const std::string & filename) { @@ -23,14 +24,14 @@ TransitionSet::TransitionSet(const std::string & filename) std::fclose(file); } -std::vector<std::pair<Transition &, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c) +std::vector<std::pair<Transition*, int>> TransitionSet::getAppliableTransitionsCosts(const Config & c) { - using Pair = std::pair<Transition &, int>; + using Pair = std::pair<Transition*, int>; std::vector<Pair> appliableTransitions; - for (auto & transition : transitions) - if (transition.appliable(c)) - appliableTransitions.emplace_back(transition, transition.getCost(c)); + for (unsigned int i = 0; i < transitions.size(); i++) + if (transitions[i].appliable(c)) + appliableTransitions.emplace_back(&transitions[i], transitions[i].getCost(c)); std::sort(appliableTransitions.begin(), appliableTransitions.end(), [](const Pair & a, const Pair & b) @@ -41,3 +42,28 @@ std::vector<std::pair<Transition &, int>> TransitionSet::getAppliableTransitions return appliableTransitions; } +Transition * TransitionSet::getBestAppliableTransition(const Config & c) +{ + Transition * result = nullptr; + int bestCost = std::numeric_limits<int>::max(); + + for (unsigned int i = 0; i < transitions.size(); i++) + { + if (!transitions[i].appliable(c)) + continue; + + int cost = transitions[i].getCost(c); + + if (cost == 0) + return &transitions[i]; + + if (cost < bestCost) + { + result = &transitions[i]; + bestCost = cost; + } + } + + return result; +} + -- GitLab