Skip to content
Snippets Groups Projects
Commit 0e820713 authored by Franck Dary
Browse files

Prepared for classifier of error detection

parent d457a01c
No related branches found
No related tags found
No related merge requests found
...@@ -51,7 +51,7 @@ po::options_description getOptionsDescription() ...@@ -51,7 +51,7 @@ po::options_description getOptionsDescription()
"For each state of the Config, show its feature representation") "For each state of the Config, show its feature representation")
("readSize", po::value<int>()->default_value(0), ("readSize", po::value<int>()->default_value(0),
"The number of lines of input that will be read and stored in memory at once.") "The number of lines of input that will be read and stored in memory at once.")
("dictCapacity", po::value<int>()->default_value(30000), ("dictCapacity", po::value<int>()->default_value(50000),
"The maximal size of each Dict (number of differents embeddings).") "The maximal size of each Dict (number of differents embeddings).")
("interactive", po::value<bool>()->default_value(true), ("interactive", po::value<bool>()->default_value(true),
"Is the shell interactive ? Display advancement informations") "Is the shell interactive ? Display advancement informations")
......
...@@ -79,7 +79,7 @@ po::options_description getOptionsDescription() ...@@ -79,7 +79,7 @@ po::options_description getOptionsDescription()
"The value of the token that act as a delimiter for sequences") "The value of the token that act as a delimiter for sequences")
("batchSize", po::value<int>()->default_value(50), ("batchSize", po::value<int>()->default_value(50),
"The size of each minibatch (in number of taining examples)") "The size of each minibatch (in number of taining examples)")
("dictCapacity", po::value<int>()->default_value(30000), ("dictCapacity", po::value<int>()->default_value(50000),
"The maximal size of each Dict (number of differents embeddings).") "The maximal size of each Dict (number of differents embeddings).")
("tapeToMask", po::value<std::string>()->default_value("FORM"), ("tapeToMask", po::value<std::string>()->default_value("FORM"),
"The name of the Tape for which some of the elements will be masked.") "The name of the Tape for which some of the elements will be masked.")
......
...@@ -561,8 +561,20 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na ...@@ -561,8 +561,20 @@ std::vector<Action::BasicAction> ActionBank::str2sequence(const std::string & na
auto undo = [dist](Config &, Action::BasicAction &) auto undo = [dist](Config &, Action::BasicAction &)
{ {
}; };
auto appliable = [dist](Config &, Action::BasicAction &) auto appliable = [dist](Config & c, Action::BasicAction)
{ {
std::string classifierName = c.pastActions.top().first;
int stateHistorySize = c.getStateHistory(classifierName).size();
if (c.getCurrentStateHistory().size() >= 2 && (c.getCurrentStateHistory().top() == "BACK" || c.getCurrentStateHistory().getElem(1) == "BACK"))
return false;
if (c.hashHistory.contains(c.computeHash()))
return false;
if (stateHistorySize <= dist)
return false;
return true; return true;
}; };
Action::BasicAction basicAction = Action::BasicAction basicAction =
......
...@@ -97,6 +97,7 @@ void Oracle::createDatabase() ...@@ -97,6 +97,7 @@ void Oracle::createDatabase()
str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("error_tagger", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle) [](Oracle * oracle)
{ {
return;
File file(oracle->filename, "r"); File file(oracle->filename, "r");
FILE * fd = file.getDescriptor(); FILE * fd = file.getDescriptor();
char b1[1024]; char b1[1024];
...@@ -155,6 +156,7 @@ void Oracle::createDatabase() ...@@ -155,6 +156,7 @@ void Oracle::createDatabase()
str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("error_morpho", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle) [](Oracle * oracle)
{ {
return;
File file(oracle->filename, "r"); File file(oracle->filename, "r");
FILE * fd = file.getDescriptor(); FILE * fd = file.getDescriptor();
char b1[1024]; char b1[1024];
...@@ -220,6 +222,7 @@ void Oracle::createDatabase() ...@@ -220,6 +222,7 @@ void Oracle::createDatabase()
str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle( str2oracle.emplace("error_parser", std::unique_ptr<Oracle>(new Oracle(
[](Oracle * oracle) [](Oracle * oracle)
{ {
return;
File file(oracle->filename, "r"); File file(oracle->filename, "r");
FILE * fd = file.getDescriptor(); FILE * fd = file.getDescriptor();
char b1[1024]; char b1[1024];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment