Skip to content
Snippets Groups Projects
Commit 5c948074 authored by Franck Dary's avatar Franck Dary
Browse files

Added optional instruction BanExamples in Classifier's definition

parent 5f7e4933
No related branches found
No related tags found
No related merge requests found
......@@ -152,7 +152,7 @@ void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, c
break;
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto){}))
continue;
if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
......
......@@ -26,11 +26,12 @@ class Classifier
std::vector<std::string> states;
std::filesystem::path path;
bool regression{false};
std::vector<std::tuple<std::string, std::string, std::string>> bannedExamples;
LossFunction lossFct;
private :
void initNeuralNetwork(const std::vector<std::string> & definition);
void initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
std::string getLastFilename() const;
std::string getBestFilename() const;
......@@ -54,6 +55,7 @@ class Classifier
void saveLast();
bool isRegression() const;
LossFunction & getLossFunction();
bool exampleIsBanned(const Config & config);
};
#endif
......@@ -6,7 +6,8 @@
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
{
this->name = name;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
std::size_t curIndex = 0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[curIndex], [this,&path,&curIndex](auto sm)
{
auto splited = util::split(sm.str(1), ' ');
......@@ -35,14 +36,15 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
this->transitionSets.emplace(stateName, new TransitionSet(tsFiles));
}
}
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[0], "(Transitions :) {tsFile1.ts tsFile2.ts...}"));
for (auto & it : this->transitionSets)
lossMultipliers[it.first] = 1.0;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[1], [this](auto sm)
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:LossMultiplier :|)(?:(?:\\s|\\t)*)\\{(.*)\\}"), definition[curIndex], [this,&curIndex](auto sm)
{
auto pairs = util::split(sm.str(1), ' ');
for (auto & it : pairs)
......@@ -58,10 +60,18 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
util::myThrow(fmt::format("caugh '{}' in '{}'", e.what(), it));
}
}
curIndex++;
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition);
while (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:BanExamples :|)(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Target\\{(.*)\\}(?:(?:\\s|\\t)*)Value\\{(.*)\\}(?:(?:\\s|\\t)*)"), definition[curIndex], [this,&curIndex](auto sm)
{
bannedExamples.emplace_back(sm.str(1), sm.str(2), sm.str(3));
curIndex++;
}));
initNeuralNetwork(definition, curIndex);
if (train)
getNN()->train();
......@@ -120,14 +130,12 @@ const std::string & Classifier::getName() const
return name;
}
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::size_t curIndex)
{
std::map<std::string,std::size_t> nbOutputsPerState;
for (auto & it : this->transitionSets)
nbOutputsPerState[it.first] = it.second->size();
std::size_t curIndex = 2;
std::string networkType;
if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Network type :|)(?:(?:\\s|\\t)*)(\\S+)"), definition[curIndex], [&curIndex,&networkType](auto sm)
{
......@@ -204,7 +212,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
for (; curIndex < definition.size(); curIndex++)
{
if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto sm){}))
if (util::doIfNameMatch(endRegex,definition[curIndex],[](auto){}))
{
curIndex++;
break;
......@@ -288,3 +296,19 @@ LossFunction & Classifier::getLossFunction()
return lossFct;
}
bool Classifier::exampleIsBanned(const Config & config)
{
for (auto t : bannedExamples)
{
auto column = std::get<0>(t);
auto splited = util::split(std::get<1>(t), '.');
auto value = std::get<2>(t);
auto object = Config::str2object(splited[0]);
int index = std::stoi(splited[1]);
if (config.getAsFeature(column, config.getRelativeWordIndex(object, index)) == value)
return true;
}
return false;
}
......@@ -75,7 +75,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
}))
util::myThrow("No predictions specified");
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto sm)
if (!util::doIfNameMatch(std::regex("Strategy"), lines[curLine++], [this,&lines,&curLine](auto)
{
strategyDefinition.clear();
if (lines[curLine] != "{")
......
......@@ -122,6 +122,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
}
std::vector<long> goldIndexes;
bool exampleIsBanned = machine.getClassifier(config.getState())->exampleIsBanned(config);
if (machine.getClassifier(config.getState())->isRegression())
{
......@@ -147,13 +148,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
}
totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
if (!exampleIsBanned)
{
totalNbExamples += context.size();
if (totalNbExamples >= (int)safetyNbExamplesMax)
util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
examplesPerState[config.getState()].addContext(context);
examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes);
examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);
}
config.setChosenActionScore(bestScore);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment