Commit 41c1ae12 authored by Franck Dary
Browse files

Support for multiple classifiers

parent 04dd8e56
......@@ -37,14 +37,14 @@ void Beam::update(ReadingMachine & machine, bool debug)
ended = false;
auto & classifier = *machine.getClassifier();
auto & classifier = *machine.getClassifier(elements[index].config.getState());
classifier.setState(elements[index].config.getState());
if (machine.hasSplitWordTransitionSet())
elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
elements[index].config.setAppliableTransitions(appliableTransitions);
auto context = classifier.getNN()->extractContext(elements[index].config).back();
......@@ -95,7 +95,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet(elements[index].config.getState()).getTransition(i)->getName());
toPrint.emplace_back(std::make_pair(score,nicePrint));
}
std::sort(toPrint.rbegin(), toPrint.rend());
......@@ -118,11 +118,11 @@ void Beam::update(ReadingMachine & machine, bool debug)
continue;
auto & config = element.config;
auto & classifier = *machine.getClassifier();
auto & classifier = *machine.getClassifier(config.getState());
classifier.setState(config.getState());
auto * transition = machine.getTransitionSet().getTransition(element.nextTransition);
auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
transition->apply(config);
config.addToHistory(transition->getName());
......
......@@ -39,11 +39,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
} catch(std::exception & e) {util::myThrow(e.what());}
baseConfig = beam[0].config;
machine.getClassifier()->setState(baseConfig.getState());
machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig);
machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -87,7 +87,7 @@ int MacaonDecode::main()
try
{
ReadingMachine machine(machinePath, modelPaths);
ReadingMachine machine(machinePath, false);
Decoder decoder(machine);
BaseConfig config(mcd, inputTSV, inputTXT);
......
......@@ -2,6 +2,7 @@
#define CLASSIFIER__H
#include <string>
#include <filesystem>
#include "TransitionSet.hpp"
#include "NeuralNetwork.hpp"
......@@ -21,25 +22,33 @@ class Classifier
std::unique_ptr<torch::optim::Optimizer> optimizer;
std::string optimizerType, optimizerParameters;
std::string state;
std::vector<std::string> states;
std::filesystem::path path;
private :
void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path);
void initNeuralNetwork(const std::vector<std::string> & definition);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
std::string getLastFilename() const;
std::string getBestFilename() const;
public :
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
TransitionSet & getTransitionSet();
NeuralNetwork & getNN();
const std::string & getName() const;
int getNbParameters() const;
void resetOptimizer();
void loadOptimizer(std::filesystem::path path);
void saveOptimizer(std::filesystem::path path);
void loadOptimizer();
void saveOptimizer();
torch::optim::Optimizer & getOptimizer();
void setState(const std::string & state);
float getLossMultiplier();
const std::vector<std::string> & getStates() const;
void saveDicts();
void saveBest();
void saveLast();
};
#endif
......@@ -17,28 +17,28 @@ class ReadingMachine
std::string name;
std::filesystem::path path;
std::unique_ptr<Classifier> classifier;
std::vector<std::unique_ptr<Classifier>> classifiers;
std::map<std::string, int> state2classifier;
std::vector<std::string> strategyDefinition;
std::vector<std::string> classifierDefinition;
std::string classifierName;
std::vector<std::vector<std::string>> classifierDefinitions;
std::vector<std::string> classifierNames;
std::set<std::string> predicted;
bool train;
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
private :
void readFromFile(std::filesystem::path path);
void save(const std::string & modelNameTemplate) const;
public :
ReadingMachine(std::filesystem::path path);
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
TransitionSet & getTransitionSet();
ReadingMachine(std::filesystem::path path, bool train);
TransitionSet & getTransitionSet(const std::string & state);
TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const;
const std::vector<std::string> & getStrategyDefinition() const;
Classifier * getClassifier();
Classifier * getClassifier(const std::string & state);
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
......@@ -46,11 +46,12 @@ class ReadingMachine
void saveBest() const;
void saveLast() const;
void saveDicts() const;
void loadDicts();
void loadLastSaved();
void setCountOcc(bool countOcc);
void removeRareDictElements(float rarityThreshold);
void resetClassifier();
void resetClassifiers();
void loadPretrainedClassifiers();
int getNbParameters() const;
void resetOptimizers();
};
#endif
......@@ -3,7 +3,7 @@
#include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
{
this->name = name;
if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
......@@ -13,12 +13,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
for (auto & ss : splited)
{
std::vector<std::string> tsFiles;
std::vector<std::string> states;
for (auto & elem : util::split(ss, ','))
if (std::filesystem::path(elem).extension().empty())
states.emplace_back(elem);
else
tsFiles.emplace_back(path.parent_path() / elem);
tsFiles.emplace_back(path / elem);
if (tsFiles.empty())
util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
if (states.empty())
......@@ -58,7 +57,19 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition, path.parent_path());
initNeuralNetwork(definition);
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
if (!train)
torch::load(getNN(), getBestFilename());
else if (std::filesystem::exists(getLastFilename()))
{
torch::load(getNN(), getLastFilename());
resetOptimizer();
loadOptimizer();
}
}
int Classifier::getNbParameters() const
......@@ -89,7 +100,7 @@ const std::string & Classifier::getName() const
return name;
}
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path)
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
{
std::map<std::string,std::size_t> nbOutputsPerState;
for (auto & it : this->transitionSets)
......@@ -108,7 +119,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
if (networkType == "Random")
this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
else if (networkType == "Modular")
initModular(definition, curIndex, nbOutputsPerState, path);
initModular(definition, curIndex, nbOutputsPerState);
else
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
......@@ -120,14 +131,16 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
}
void Classifier::loadOptimizer(std::filesystem::path path)
void Classifier::loadOptimizer()
{
torch::load(*optimizer, path);
auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
if (std::filesystem::exists(optimizerPath))
torch::load(*optimizer, optimizerPath);
}
void Classifier::saveOptimizer(std::filesystem::path path)
void Classifier::saveOptimizer()
{
torch::save(*optimizer, path);
torch::save(*optimizer, fmt::format("{}/{}_optimizer.pt", path.string(), name));
}
torch::optim::Optimizer & Classifier::getOptimizer()
......@@ -141,7 +154,7 @@ void Classifier::setState(const std::string & state)
nn->setState(state);
}
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path)
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
{
std::string anyBlanks = "(?:(?:\\s|\\t)*)";
std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
......@@ -188,3 +201,34 @@ float Classifier::getLossMultiplier()
return lossMultipliers.at(state);
}
const std::vector<std::string> & Classifier::getStates() const
{
return states;
}
void Classifier::saveDicts()
{
getNN()->saveDicts(path);
}
// Returns the location of the "best" model checkpoint of this classifier,
// i.e. the file '<name>_best.pt' inside the classifier's directory.
//
// The path is built with std::filesystem::path::operator/ rather than by
// formatting "{}/{}": when the directory is empty (for instance when the
// machine file was given as a bare relative filename, so that its
// parent_path() is empty) string formatting would produce '/<name>_best.pt',
// which designates the filesystem root instead of the current directory.
std::string Classifier::getBestFilename() const
{
  return (path / fmt::format("{}_best.pt", name)).string();
}
// Returns the location of the most recent ("last") model checkpoint of this
// classifier, i.e. the file '<name>_last.pt' inside the classifier's directory.
//
// The path is built with std::filesystem::path::operator/ rather than by
// formatting "{}/{}": when the directory is empty (for instance when the
// machine file was given as a bare relative filename, so that its
// parent_path() is empty) string formatting would produce '/<name>_last.pt',
// which designates the filesystem root instead of the current directory.
std::string Classifier::getLastFilename() const
{
  return (path / fmt::format("{}_last.pt", name)).string();
}
// Serializes the neural network to the "best" checkpoint file
// (see getBestFilename for its location).
void Classifier::saveBest()
{
  const std::string destination = getBestFilename();
  torch::save(getNN(), destination);
}
// Writes the current training checkpoint: first the neural network goes to
// the "last" checkpoint file (see getLastFilename), then the optimizer is
// saved through saveOptimizer.
void Classifier::saveLast()
{
  const std::string checkpointFile = getLastFilename();
  torch::save(getNN(), checkpointFile);
  saveOptimizer();
}
#include "ReadingMachine.hpp"
#include "util.hpp"
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
ReadingMachine::ReadingMachine(std::filesystem::path path, bool train) : path(path), train(train)
{
readFromFile(path);
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
{
readFromFile(path);
loadDicts();
trainMode(false);
classifier->getNN()->registerEmbeddings();
classifier->getNN()->to(NeuralNetworkImpl::device);
if (models.size() > 1)
util::myThrow("having more than one model file is not supported");
try
{
torch::load(classifier->getNN(), models[0]);
} catch (std::exception & e)
{
util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
}
}
void ReadingMachine::readFromFile(std::filesystem::path path)
{
std::FILE * file = std::fopen(path.c_str(), "r");
......@@ -57,22 +36,28 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
util::myThrow("No name specified");
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine], [this,path,&lines,&curLine](auto sm)
{
classifierDefinition.clear();
classifierName = sm.str(1);
curLine++;
classifierDefinitions.emplace_back();
classifierNames.emplace_back(sm.str(1));
if (lines[curLine] != "{")
util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
for (curLine++; curLine < lines.size(); curLine++)
{
if (lines[curLine] == "}")
{
curLine++;
break;
classifierDefinition.emplace_back(lines[curLine]);
}
classifierDefinitions.back().emplace_back(lines[curLine]);
}
classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
classifiers.emplace_back(new Classifier(sm.str(1), path.parent_path(), classifierDefinitions.back(), train));
for (auto state : classifiers.back()->getStates())
state2classifier[state] = classifiers.size()-1;
}));
if (!classifier.get())
if (classifiers.empty())
util::myThrow("No Classifier specified");
util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
......@@ -108,9 +93,9 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
TransitionSet & ReadingMachine::getTransitionSet()
TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
{
return classifier->getTransitionSet();
return classifiers[state2classifier.at(state)]->getTransitionSet();
}
bool ReadingMachine::hasSplitWordTransitionSet() const
......@@ -128,37 +113,29 @@ const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
return strategyDefinition;
}
Classifier * ReadingMachine::getClassifier()
Classifier * ReadingMachine::getClassifier(const std::string & state)
{
return classifier.get();
return classifiers[state2classifier.at(state)].get();
}
void ReadingMachine::saveDicts() const
{
classifier->getNN()->saveDicts(path.parent_path());
}
void ReadingMachine::loadDicts()
{
classifier->getNN()->loadDicts(path.parent_path());
}
void ReadingMachine::save(const std::string & modelNameTemplate) const
{
saveDicts();
auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier);
for (auto & classifier : classifiers)
classifier->saveDicts();
}
void ReadingMachine::saveBest() const
{
save(defaultModelFilename);
saveDicts();
for (auto & classifier : classifiers)
classifier->saveBest();
}
void ReadingMachine::saveLast() const
{
save(lastModelFilename);
saveDicts();
for (auto & classifier : classifiers)
classifier->saveLast();
}
bool ReadingMachine::isPredicted(const std::string & columnName) const
......@@ -173,34 +150,47 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
void ReadingMachine::trainMode(bool isTrainMode)
{
classifier->getNN()->train(isTrainMode);
for (auto & classifier : classifiers)
classifier->getNN()->train(isTrainMode);
}
void ReadingMachine::setDictsState(Dict::State state)
{
classifier->getNN()->setDictsState(state);
for (auto & classifier : classifiers)
classifier->getNN()->setDictsState(state);
}
void ReadingMachine::loadLastSaved()
void ReadingMachine::setCountOcc(bool countOcc)
{
auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
if (!lastSavedModel.empty())
torch::load(classifier->getNN(), lastSavedModel[0]);
for (auto & classifier : classifiers)
classifier->getNN()->setCountOcc(countOcc);
}
void ReadingMachine::setCountOcc(bool countOcc)
void ReadingMachine::removeRareDictElements(float rarityThreshold)
{
classifier->getNN()->setCountOcc(countOcc);
for (auto & classifier : classifiers)
classifier->getNN()->removeRareDictElements(rarityThreshold);
}
void ReadingMachine::removeRareDictElements(float rarityThreshold)
void ReadingMachine::resetClassifiers()
{
classifier->getNN()->removeRareDictElements(rarityThreshold);
for (unsigned int i = 0; i < classifiers.size(); i++)
classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train));
}
int ReadingMachine::getNbParameters() const
{
int sum = 0;
for (auto & classifier : classifiers)
sum += classifier->getNbParameters();
return sum;
}
void ReadingMachine::resetClassifier()
void ReadingMachine::resetOptimizers()
{
classifier.reset(new Classifier(classifierName, path, classifierDefinition));
loadDicts();
for (auto & classifier : classifiers)
classifier->resetOptimizer();
}
......@@ -18,7 +18,9 @@ void DictHolder::saveDict(std::filesystem::path path)
void DictHolder::loadDict(std::filesystem::path path)
{
dict.reset(new Dict((path / filename()).c_str(), dict->getState()));
auto dictPath = path / filename();
if (std::filesystem::exists(dictPath))
dict.reset(new Dict(dictPath.c_str(), dict->getState()));
}
Dict & DictHolder::getDict()
......
......@@ -156,7 +156,7 @@ int MacaonTrain::main()
try
{
ReadingMachine machine(machinePath.string());
ReadingMachine machine(machinePath.string(), true);
BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
......@@ -164,14 +164,6 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize, lossFunction);
Decoder decoder(machine);
if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
{
machine.loadDicts();
machine.getClassifier()->getNN()->registerEmbeddings();
machine.loadLastSaved();
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
}
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
auto trainInfos = machinePath.parent_path() / "train.info";
......@@ -195,13 +187,6 @@ int MacaonTrain::main()
std::fclose(f);
}
auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
if (std::filesystem::exists(trainInfos))
{
machine.getClassifier()->resetOptimizer();
machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
}
for (; currentEpoch < nbEpoch; currentEpoch++)
{
bool saved = false;
......@@ -231,14 +216,12 @@ int MacaonTrain::main()
{
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
{
machine.resetClassifier();
machine.resetClassifiers();
machine.trainMode(currentEpoch == 0);
machine.getClassifier()->getNN()->registerEmbeddings();
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
}
machine.getClassifier()->resetOptimizer();
machine.resetOptimizers();
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
{
......@@ -290,8 +273,9 @@ int MacaonTrain::main()
bestDevScore = devScoreMean;
machine.saveBest();
}
machine.saveLast();
machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
if (printAdvancement)
fmt::print(stderr, "\r{:80}\r", "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
......
......@@ -93,7 +93,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
config.addPredicted(machine.getPredicted());
config.setStrategy(machine.getStrategyDefinition());
config.setState(config.getStrategy().getInitialState());
machine.getClassifier()->setState(config.getState());
machine.getClassifier(config.getState())->setState(config.getState());
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
......@@ -111,14 +111,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
config.setAppliableTransitions(appliableTransitions);
std::vector<std::vector<long>> context;
try
{
context = machine.getClassifier()->getNN()->extractContext(config);
context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
......@@ -126,14 +127,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
Transition * transition = nullptr;
auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);