From 0e82071312a9c6984fb681e1a67aa4dd54cff495 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 24 Apr 2019 14:36:43 +0200 Subject: [PATCH] Prepared for classifier of error detection --- decoder/src/macaon_decode.cpp | 2 +- trainer/src/macaon_train.cpp | 2 +- transition_machine/src/ActionBank.cpp | 20 ++++++++++++++++---- transition_machine/src/Oracle.cpp | 3 +++ 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/decoder/src/macaon_decode.cpp b/decoder/src/macaon_decode.cpp index b005aa8..f6e9b94 100644 --- a/decoder/src/macaon_decode.cpp +++ b/decoder/src/macaon_decode.cpp @@ -51,7 +51,7 @@ po::options_description getOptionsDescription() "For each state of the Config, show its feature representation") ("readSize", po::value<int>()->default_value(0), "The number of lines of input that will be read and stored in memory at once.") - ("dictCapacity", po::value<int>()->default_value(30000), + ("dictCapacity", po::value<int>()->default_value(50000), "The maximal size of each Dict (number of differents embeddings).") ("interactive", po::value<bool>()->default_value(true), "Is the shell interactive ? Display advancement informations") diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp index aaa3f7c..7a9a2b8 100644 --- a/trainer/src/macaon_train.cpp +++ b/trainer/src/macaon_train.cpp @@ -79,7 +79,7 @@ po::options_description getOptionsDescription() "The value of the token that act as a delimiter for sequences") ("batchSize", po::value<int>()->default_value(50), "The size of each minibatch (in number of taining examples)") - ("dictCapacity", po::value<int>()->default_value(30000), + ("dictCapacity", po::value<int>()->default_value(50000), "The maximal size of each Dict (number of differents embeddings).") ("tapeToMask", po::value<std::string>()->default_value("FORM"), "The name of the Tape for which some of the elements will be masked.") diff --git a/transition_machine/src/ActionBank.cpp b/transition_machine/src/ActionBank.cpp index 7a304b6..d85fc39 100644 --- a/transition_machine/src/ActionBank.cpp +++ b/transition_machine/src/ActionBank.cpp @@ -561,14 +561,26 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na auto undo = [dist](Config &, Action::BasicAction &) { }; - auto appliable = [dist](Config &, Action::BasicAction &) + auto appliable = [dist](Config & c, Action::BasicAction) { + std::string classifierName = c.pastActions.top().first; + int stateHistorySize = c.getStateHistory(classifierName).size(); + + if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK")) + return false; + + if (c.hashHistory.contains(c.computeHash())) + return false; + + if (stateHistorySize <= dist) + return false; + return true; }; - Action::BasicAction basicAction = - {Action::BasicAction::Type::Write, "", apply, undo, appliable}; + Action::BasicAction basicAction = + {Action::BasicAction::Type::Write, "", apply, undo, appliable}; - sequence.emplace_back(basicAction); + sequence.emplace_back(basicAction); } else { diff --git a/transition_machine/src/Oracle.cpp b/transition_machine/src/Oracle.cpp index 805b62c..bc54f85 100644 --- a/transition_machine/src/Oracle.cpp +++ b/transition_machine/src/Oracle.cpp @@ -97,6 +97,7 @@ void Oracle::createDatabase() str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle( [](Oracle * oracle) { + return; File file(oracle->filename, "r"); FILE * fd = file.getDescriptor(); char b1[1024]; @@ -155,6 +156,7 @@ void Oracle::createDatabase() str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle( [](Oracle * oracle) { + return; File file(oracle->filename, "r"); FILE * fd = file.getDescriptor(); char b1[1024]; @@ -220,6 +222,7 @@ void Oracle::createDatabase() str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle( [](Oracle * oracle) { + return; File file(oracle->filename, "r"); FILE * fd = file.getDescriptor(); char b1[1024]; -- GitLab