Skip to content
Snippets Groups Projects
Commit 5c948074 authored by Franck Dary's avatar Franck Dary
Browse files

Added optional instruction BanExamples in Classifier's definition

parent 5f7e4933
No related branches found
No related tags found
No related merge requests found
...@@ -152,7 +152,7 @@ void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, c ...@@ -152,7 +152,7 @@ void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, c
break; break;
if (buffer[std::strlen(buffer)-1] == '\n') if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0'; buffer[std::strlen(buffer)-1] = '\0';
if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){})) if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto){}))
continue; continue;
if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm) if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
......
...@@ -26,11 +26,12 @@ class Classifier ...@@ -26,11 +26,12 @@ class Classifier
std::vector<std::string> states; std::vector<std::string> states;
std::filesystem::path path; std::filesystem::path path;
bool regression{false}; bool regression{false};
std::vector<std::tuple<std::string, std::string, std::string>> bannedExamples;
LossFunction lossFct; LossFunction lossFct;
private : private :
void initNeuralNetwork(const std::vector<std::string> & definition); void initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
std::string getLastFilename() const; std::string getLastFilename() const;
std::string getBestFilename() const; std::string getBestFilename() const;
...@@ -54,6 +55,7 @@ class Classifier ...@@ -54,6 +55,7 @@ class Classifier
void saveLast(); void saveLast();
bool isRegression() const; bool isRegression() const;
LossFunction & getLossFunction(); LossFunction & getLossFunction();
bool exampleIsBanned(const Config & config);
}; };
#endif #endif
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path) Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
{ {
this->name = name; this->name = name;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) std::size_t curIndex = 0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[curIndex], [this,&path,&curIndex](auto sm)
{ {
auto splited = util::split(sm.str(1), ' '); auto splited = util::split(sm.str(1), ' ');
...@@ -35,14 +36,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -35,14 +36,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
this->transitionSets.emplace(stateName, new TransitionSet(tsFiles)); this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
} }
} }
curIndex++;
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
for (auto & it : this->transitionSets) for (auto & it : this->transitionSets)
lossMultipliers[it.first] = 1.0; lossMultipliers[it.first] = 1.0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm) if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [this,&curIndex](auto sm)
{ {
auto pairs = util::split(sm.str(1), ' '); auto pairs = util::split(sm.str(1), ' ');
for (auto & it : pairs) for (auto & it : pairs)
...@@ -58,10 +60,18 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -58,10 +60,18 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it)); util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
} }
} }
curIndex++;
})) }))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}")); util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition); while (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BanExamples :|)(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Target\\{(.*)\\}(?:(?:\\s|\\t)*)Value\\{(.*)\\}(?:(?:\\s|\\t)*)"), definition[curIndex], [this,&curIndex](auto sm)
{
bannedExamples.emplace_back(sm.str(1), sm.str(2), sm.str(3));
curIndex++;
}));
initNeuralNetwork(definition, curIndex);
if (train) if (train)
getNN()->train(); getNN()->train();
...@@ -120,14 +130,12 @@ const std::string & Classifier::getName() const ...@@ -120,14 +130,12 @@ const std::string & Classifier::getName() const
return name; return name;
} }
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex)
{ {
std::map<std::string,std::size_t> nbOutputsPerState; std::map<std::string,std::size_t> nbOutputsPerState;
for (auto & it : this->transitionSets) for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size(); nbOutputsPerState[it.first] = it.second->size();
std::size_t curIndex = 2;
std::string networkType; std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
{ {
...@@ -204,7 +212,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s ...@@ -204,7 +212,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
for (; curIndex < definition.size(); curIndex++) for (; curIndex < definition.size(); curIndex++)
{ {
if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){})) if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto){}))
{ {
curIndex++; curIndex++;
break; break;
...@@ -288,3 +296,19 @@ LossFunction & Classifier::getLossFunction() ...@@ -288,3 +296,19 @@ LossFunction & Classifier::getLossFunction()
return lossFct; return lossFct;
} }
bool Classifier::exampleIsBanned(const Config & config)
{
for (auto t : bannedExamples)
{
auto column = std::get<0>(t);
auto splited = util::split(std::get<1>(t), '.');
auto value = std::get<2>(t);
auto object = Config::str2object(splited[0]);
int index = std::stoi(splited[1]);
if (config.getAsFeature(column, config.getRelativeWordIndex(object, index)) == value)
return true;
}
return false;
}
...@@ -75,7 +75,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) ...@@ -75,7 +75,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
})) }))
util::myThrow("No predictions specified"); util::myThrow("No predictions specified");
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto)
{ {
strategyDefinition.clear(); strategyDefinition.clear();
if (lines[curLine] != "{") if (lines[curLine] != "{")
......
...@@ -122,6 +122,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -122,6 +122,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
} }
std::vector<long> goldIndexes; std::vector<long> goldIndexes;
bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
if (machine.getClassifier(config.getState())->isRegression()) if (machine.getClassifier(config.getState())->isRegression())
{ {
...@@ -147,13 +148,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -147,13 +148,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
} }
totalNbExamples += context.size(); if (!exampleIsBanned)
if (totalNbExamples >= (int)safetyNbExamplesMax) {
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesPerState[config.getState()].addContext(context); examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
}
config.setChosenActionScore(bestScore); config.setChosenActionScore(bestScore);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment