diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 7739b1de6096be6fedb313b8d81ed974ee493e99..a6635649f35bd3f86e46055ce5cace255470c378 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -152,7 +152,7 @@ void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, c break; if (buffer[std::strlen(buffer)-1] == '\n') buffer[std::strlen(buffer)-1] = '\0'; - if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){})) + if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto){})) continue; if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm) diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index f8f79fe88e87c10f55c0279e9e69a3ac0cde451d..17bfc8487da18fa2079b295a32a69c5a80edb849 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -26,11 +26,12 @@ class Classifier std::vector<std::string> states; std::filesystem::path path; bool regression{false}; + std::vector<std::tuple<std::string, std::string, std::string>> bannedExamples; /* (column, target, value) triples, one per BanExamples line of the definition */ LossFunction lossFct; private : - void initNeuralNetwork(const std::vector<std::string> & definition); + void initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex); void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState); std::string getLastFilename() const; std::string getBestFilename() const; @@ -54,6 +55,7 @@ class Classifier void saveLast(); bool isRegression() const; LossFunction & getLossFunction(); + bool exampleIsBanned(const Config & config); /* true when config matches at least one entry of bannedExamples */ }; #endif diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 76bb7376dc2aa5dcde8c142e2f36dd2b63d19168..aa11dc9704f6d72c1c1159ea1953ec8fa82ddcac 100644 --- a/reading_machine/src/Classifier.cpp +++ 
b/reading_machine/src/Classifier.cpp @@ -6,7 +6,8 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path) { this->name = name; - if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm) + std::size_t curIndex = 0; + if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[curIndex], [this,&path,&curIndex](auto sm) { auto splited = util::split(sm.str(1), ' '); @@ -35,14 +36,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std this->transitionSets.emplace(stateName, new TransitionSet(tsFiles)); } } - + + curIndex++; })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}")); for (auto & it : this->transitionSets) lossMultipliers[it.first] = 1.0; - if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm) + if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [this,&curIndex](auto sm) { auto pairs = util::split(sm.str(1), ' '); for (auto & it : pairs) @@ -58,10 +60,18 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it)); } } + + curIndex++; })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}")); - initNeuralNetwork(definition); + while (curIndex < definition.size() && util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BanExamples :|)(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Target\\{(.*)\\}(?:(?:\\s|\\t)*)Value\\{(.*)\\}(?:(?:\\s|\\t)*)"), definition[curIndex], [this,&curIndex](auto sm) /* consume every consecutive BanExamples line; the size test comes first because definition[curIndex] is unchecked and the ban lines may be the last lines of the definition */ + { + bannedExamples.emplace_back(sm.str(1), sm.str(2), 
sm.str(3)); + curIndex++; + })); + + initNeuralNetwork(definition, curIndex); if (train) getNN()->train(); @@ -120,14 +130,12 @@ const std::string & Classifier::getName() const return name; } -void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) +void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex) { std::map<std::string,std::size_t> nbOutputsPerState; for (auto & it : this->transitionSets) nbOutputsPerState[it.first] = it.second->size(); - std::size_t curIndex = 2; - std::string networkType; if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm) { @@ -204,7 +212,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s for (; curIndex < definition.size(); curIndex++) { - if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){})) + if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto){})) { curIndex++; break; @@ -288,3 +296,19 @@ LossFunction & Classifier::getLossFunction() return lossFct; } +bool Classifier::exampleIsBanned(const Config & config) /* Returns true as soon as one (column, target, value) entry of bannedExamples matches config, false when none does. Each target is split on '.' into an object name and an integer index. */ +{ + for (const auto & [column, target, value] : bannedExamples) /* bind by reference: no tuple or string copies per call */ + { + auto splited = util::split(target, '.'); + if (splited.size() < 2) /* a target without '.' would otherwise make splited[1] an out-of-bounds read */ + util::myThrow(fmt::format("Invalid ban target '{}', expected '{}'\n", target, "object.index")); + auto object = Config::str2object(splited[0]); + int index = std::stoi(splited[1]); + if (config.getAsFeature(column, config.getRelativeWordIndex(object, index)) == value) + return true; + } + + return false; +} + diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index 216088dea204098d719b3ea3a052154671fbe5aa..4336d3adab6957dfb50dd9c1e7fd9e2f5221cda9 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -75,7 +75,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path) })) util::myThrow("No predictions specified"); - if 
(!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm) + if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto) { strategyDefinition.clear(); if (lines[curLine] != "{") diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 7f32d31bf7cf775a445c3a3beb1ca0fea6349004..ddb9d0c881260ebe43434ae2ef8d21d3b2ffdf6b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -122,6 +122,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } std::vector<long> goldIndexes; + bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config); if (machine.getClassifier(config.getState())->isRegression()) { @@ -147,13 +148,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } - totalNbExamples += context.size(); - if (totalNbExamples >= (int)safetyNbExamplesMax) - util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); + if (!exampleIsBanned) + { + totalNbExamples += context.size(); + if (totalNbExamples >= (int)safetyNbExamplesMax) + util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); - examplesPerState[config.getState()].addContext(context); - examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); - examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); + examplesPerState[config.getState()].addContext(context); + examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); + examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); + } config.setChosenActionScore(bestScore);