diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 08606efd0a6db74a796a5310a6b1e9f16428faaa..ee2fd0df190e1a38c669caaad838db0ab54c90a7 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -37,14 +37,14 @@ void Beam::update(ReadingMachine & machine, bool debug)
 
     ended = false;
 
-    auto & classifier = *machine.getClassifier();
+    auto & classifier = *machine.getClassifier(elements[index].config.getState());
     classifier.setState(elements[index].config.getState());
 
     if (machine.hasSplitWordTransitionSet())
      elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
 
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
+    auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config);
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
@@ -95,7 +95,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
       for (unsigned int i = 0; i < prediction.size(0); i++)
       {
         float score = prediction[i].item<float>();
-        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
+        std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet(elements[index].config.getState()).getTransition(i)->getName());
         toPrint.emplace_back(std::make_pair(score,nicePrint));
       }
       std::sort(toPrint.rbegin(), toPrint.rend());
@@ -118,11 +118,11 @@ void Beam::update(ReadingMachine & machine, bool debug)
      continue;
 
    auto & config = element.config;
-    auto & classifier = *machine.getClassifier();
+    auto & classifier = *machine.getClassifier(config.getState());
    classifier.setState(config.getState());
 
-    auto * transition = machine.getTransitionSet().getTransition(element.nextTransition);
+    auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
    transition->apply(config);
    config.addToHistory(transition->getName());
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index f9c437787e991b9a9b99eef335eedc7a0913c3f3..7739b1de6096be6fedb313b8d81ed974ee493e99 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -39,11 +39,11 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
  } catch(std::exception & e) {util::myThrow(e.what());}
 
  baseConfig = beam[0].config;
-  machine.getClassifier()->setState(baseConfig.getState());
+  machine.getClassifier(baseConfig.getState())->setState(baseConfig.getState());
 
-  if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
+  if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
  {
-    machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig);
+    machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
    if (debug)
    {
      fmt::print(stderr, "Forcing EOS transition\n");
diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp
index a3ddb73743fb59a97ebf12890d99fe5c6da2e14a..bda35a890ec512d970875dac82a01222b37c88d2 100644
--- a/decoder/src/MacaonDecode.cpp
+++ b/decoder/src/MacaonDecode.cpp
@@ -87,7 +87,7 @@ int MacaonDecode::main()
  try
  {
-    ReadingMachine machine(machinePath, modelPaths);
+    ReadingMachine machine(machinePath, false);
    Decoder decoder(machine);
 
    BaseConfig config(mcd, inputTSV, inputTXT);
diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp
index a5f7d21b90aa02f659fbdd6be26afdb3def9521f..3e5e9507175db5cc28e3af391dc622da6c5f4ec2 100644
--- a/reading_machine/include/Classifier.hpp
+++ b/reading_machine/include/Classifier.hpp
@@ -2,6 +2,7 @@
 #define CLASSIFIER__H
 
 #include <string>
+#include <filesystem>
 #include "TransitionSet.hpp"
 #include "NeuralNetwork.hpp"
 
@@ -21,25 +22,33 @@ class Classifier
  std::unique_ptr<torch::optim::Optimizer> optimizer;
  std::string optimizerType, optimizerParameters;
  std::string state;
+  std::vector<std::string> states;
+  std::filesystem::path path;
 
  private :
 
-  void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path);
-  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path);
+  void initNeuralNetwork(const std::vector<std::string> & definition);
+  void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
+  std::string getLastFilename() const;
+  std::string getBestFilename() const;
 
  public :
 
-  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition);
+  Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
  TransitionSet & getTransitionSet();
  NeuralNetwork & getNN();
  const std::string & getName() const;
  int getNbParameters() const;
  void resetOptimizer();
-  void loadOptimizer(std::filesystem::path path);
-  void saveOptimizer(std::filesystem::path path);
+  void loadOptimizer();
+  void saveOptimizer();
  torch::optim::Optimizer & getOptimizer();
  void setState(const std::string & state);
  float getLossMultiplier();
+  const std::vector<std::string> & getStates() const;
+  void saveDicts();
+  void saveBest();
+  void saveLast();
 };
 
 #endif
diff --git a/reading_machine/include/ReadingMachine.hpp b/reading_machine/include/ReadingMachine.hpp
index 13f5cbc1108c31ac858e99c5a2490627491e67e3..a974ac47715775054b9a84216ca7a08834d284de 100644
--- a/reading_machine/include/ReadingMachine.hpp
+++ b/reading_machine/include/ReadingMachine.hpp
@@ -17,28 +17,28 @@ class ReadingMachine
  std::string name;
  std::filesystem::path path;
-  std::unique_ptr<Classifier> classifier;
+  std::vector<std::unique_ptr<Classifier>> classifiers;
+  std::map<std::string, int> state2classifier;
  std::vector<std::string> strategyDefinition;
-  std::vector<std::string> classifierDefinition;
-  std::string classifierName;
+  std::vector<std::vector<std::string>> classifierDefinitions;
+  std::vector<std::string> classifierNames;
  std::set<std::string> predicted;
+  bool train;
  std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
 
  private :
 
  void readFromFile(std::filesystem::path path);
-  void save(const std::string & modelNameTemplate) const;
 
  public :
 
-  ReadingMachine(std::filesystem::path path);
-  ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models);
-  TransitionSet & getTransitionSet();
+  ReadingMachine(std::filesystem::path path, bool train);
+  TransitionSet & getTransitionSet(const std::string & state);
  TransitionSet & getSplitWordTransitionSet();
  bool hasSplitWordTransitionSet() const;
  const std::vector<std::string> & getStrategyDefinition() const;
-  Classifier * getClassifier();
+  Classifier * getClassifier(const std::string & state);
  bool isPredicted(const std::string & columnName) const;
  const std::set<std::string> & getPredicted() const;
  void trainMode(bool isTrainMode);
@@ -46,11 +46,12 @@ class ReadingMachine
  void saveBest() const;
  void saveLast() const;
  void saveDicts() const;
-  void loadDicts();
-  void loadLastSaved();
  void setCountOcc(bool countOcc);
  void removeRareDictElements(float rarityThreshold);
-  void resetClassifier();
+  void resetClassifiers();
+  void loadPretrainedClassifiers();
+  int getNbParameters() const;
+  void resetOptimizers();
 };
 
 #endif
diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 323bdb0e1301c7e4fd96c94fb610bed103aff3b3..c22e21b798b9dbf0d6ccd67c2f151dcd7b915eae 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -3,7 +3,7 @@
 #include "RandomNetwork.hpp"
 #include "ModularNetwork.hpp"
 
-Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition)
+Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
 {
  this->name = name;
  if (!util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Transitions :|)(?:(?:\\s|\\t)*)\\{(.+)\\}"), definition[0], [this,&path](auto sm)
@@ -13,12 +13,11 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
    for (auto & ss : splited)
    {
      std::vector<std::string> tsFiles;
-      std::vector<std::string> states;
      for (auto & elem : util::split(ss, ','))
        if (std::filesystem::path(elem).extension().empty())
          states.emplace_back(elem);
        else
-          tsFiles.emplace_back(path.parent_path() / elem);
+          tsFiles.emplace_back(path / elem);
      if (tsFiles.empty())
        util::myThrow(fmt::format("invalid '{}' no .ts files specified", ss));
      if (states.empty())
@@ -58,7 +57,19 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
    }))
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
 
-  initNeuralNetwork(definition, path.parent_path());
+  initNeuralNetwork(definition);
+
+  getNN()->loadDicts(path);
+  getNN()->registerEmbeddings();
+
+  if (!train)
+    torch::load(getNN(), getBestFilename());
+  else if (std::filesystem::exists(getLastFilename()))
+  {
+    torch::load(getNN(), getLastFilename());
+    resetOptimizer();
+    loadOptimizer();
+  }
 }
 
 int Classifier::getNbParameters() const
@@ -89,7 +100,7 @@ const std::string & Classifier::getName() const
  return name;
 }
 
-void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path)
+void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
 {
  std::map<std::string,std::size_t> nbOutputsPerState;
  for (auto & it : this->transitionSets)
@@ -108,7 +119,7 @@
  if (networkType == "Random")
    this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
  else if (networkType == "Modular")
-    initModular(definition, curIndex, nbOutputsPerState, path);
+    initModular(definition, curIndex, nbOutputsPerState);
  else
    util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
 
@@ -120,14 +131,16 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition,
    util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers)));
 }
 
-void Classifier::loadOptimizer(std::filesystem::path path)
+void Classifier::loadOptimizer()
 {
-  torch::load(*optimizer, path);
+  auto optimizerPath = std::filesystem::path(fmt::format("{}/{}_optimizer.pt", path.string(), name));
+  if (std::filesystem::exists(optimizerPath))
+    torch::load(*optimizer, optimizerPath);
 }
 
-void Classifier::saveOptimizer(std::filesystem::path path)
+void Classifier::saveOptimizer()
 {
-  torch::save(*optimizer, path);
+  torch::save(*optimizer, fmt::format("{}/{}_optimizer.pt", path.string(), name));
 }
 
 torch::optim::Optimizer & Classifier::getOptimizer()
@@ -141,7 +154,7 @@ void Classifier::setState(const std::string & state)
  nn->setState(state);
 }
 
-void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path)
+void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
 {
  std::string anyBlanks = "(?:(?:\\s|\\t)*)";
  std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
@@ -188,3 +201,34 @@ float Classifier::getLossMultiplier()
  return lossMultipliers.at(state);
 }
 
+const std::vector<std::string> & Classifier::getStates() const
+{
+  return states;
+}
+
+void Classifier::saveDicts()
+{
+  getNN()->saveDicts(path);
+}
+
+std::string Classifier::getBestFilename() const
+{
+  return fmt::format("{}/{}_best.pt", path.string(), name);
+}
+
+std::string Classifier::getLastFilename() const
+{
+  return fmt::format("{}/{}_last.pt", path.string(), name);
+}
+
+void Classifier::saveBest()
+{
+  torch::save(getNN(), getBestFilename());
+}
+
+void Classifier::saveLast()
+{
+  torch::save(getNN(), getLastFilename());
+  saveOptimizer();
+}
+
diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp
index f5fd3c4c0c063d0b7b5e29744e9d01a4f29d6d20..216088dea204098d719b3ea3a052154671fbe5aa 100644
--- a/reading_machine/src/ReadingMachine.cpp
+++ b/reading_machine/src/ReadingMachine.cpp
@@ -1,32 +1,11 @@
 #include "ReadingMachine.hpp"
 #include "util.hpp"
 
-ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
+ReadingMachine::ReadingMachine(std::filesystem::path path, bool train) : path(path), train(train)
 {
  readFromFile(path);
 }
 
-ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models) : path(path)
-{
-  readFromFile(path);
-
-  loadDicts();
-  trainMode(false);
-  classifier->getNN()->registerEmbeddings();
-  classifier->getNN()->to(NeuralNetworkImpl::device);
-
-  if (models.size() > 1)
-    util::myThrow("having more than one model file is not supported");
-
-  try
-  {
-    torch::load(classifier->getNN(), models[0]);
-  } catch (std::exception & e)
-  {
-    util::myThrow(fmt::format("error when loading '{}' : {}", models[0].string(), e.what()));
-  }
-}
-
 void ReadingMachine::readFromFile(std::filesystem::path path)
 {
  std::FILE * file = std::fopen(path.c_str(), "r");
@@ -57,22 +36,28 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
  if (!util::doIfNameMatch(std::regex("Name : (.+)"), lines[curLine++], [this](auto sm){name = sm[1];}))
    util::myThrow("No name specified");
 
-  while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine++], [this,path,&lines,&curLine](auto sm)
+  while (util::doIfNameMatch(std::regex("Classifier : (.+)"), lines[curLine], [this,path,&lines,&curLine](auto sm)
  {
-    classifierDefinition.clear();
-    classifierName = sm.str(1);
+    curLine++;
+    classifierDefinitions.emplace_back();
+    classifierNames.emplace_back(sm.str(1));
    if (lines[curLine] != "{")
      util::myThrow(fmt::format("Expected '{}', got '{}' instead", "{", lines[curLine]));
    for (curLine++; curLine < lines.size(); curLine++)
    {
      if (lines[curLine] == "}")
+      {
+        curLine++;
        break;
-      classifierDefinition.emplace_back(lines[curLine]);
+      }
+      classifierDefinitions.back().emplace_back(lines[curLine]);
    }
-    classifier.reset(new Classifier(sm.str(1), path, classifierDefinition));
+    classifiers.emplace_back(new Classifier(sm.str(1), path.parent_path(), classifierDefinitions.back(), train));
+    for (auto state : classifiers.back()->getStates())
+      state2classifier[state] = classifiers.size()-1;
  }));
 
-  if (!classifier.get())
+  if (classifiers.empty())
    util::myThrow("No Classifier specified");
 
  util::doIfNameMatch(std::regex("Splitwords : (.+)"), lines[curLine], [this,path,&curLine](auto sm)
@@ -108,9 +93,9 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
  } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
 }
 
-TransitionSet & ReadingMachine::getTransitionSet()
+TransitionSet & ReadingMachine::getTransitionSet(const std::string & state)
 {
-  return classifier->getTransitionSet();
+  return classifiers[state2classifier.at(state)]->getTransitionSet();
 }
 
 bool ReadingMachine::hasSplitWordTransitionSet() const
@@ -128,37 +113,29 @@ const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
  return strategyDefinition;
 }
 
-Classifier * ReadingMachine::getClassifier()
+Classifier * ReadingMachine::getClassifier(const std::string & state)
 {
-  return classifier.get();
+  return classifiers[state2classifier.at(state)].get();
 }
 
 void ReadingMachine::saveDicts() const
 {
-  classifier->getNN()->saveDicts(path.parent_path());
-}
-
-void ReadingMachine::loadDicts()
-{
-  classifier->getNN()->loadDicts(path.parent_path());
-}
-
-void ReadingMachine::save(const std::string & modelNameTemplate) const
-{
-  saveDicts();
-
-  auto pathToClassifier = path.parent_path() / fmt::format(modelNameTemplate, classifier->getName());
-  torch::save(classifier->getNN(), pathToClassifier);
+  for (auto & classifier : classifiers)
+    classifier->saveDicts();
 }
 
 void ReadingMachine::saveBest() const
 {
-  save(defaultModelFilename);
+  saveDicts();
+  for (auto & classifier : classifiers)
+    classifier->saveBest();
 }
 
 void ReadingMachine::saveLast() const
 {
-  save(lastModelFilename);
+  saveDicts();
+  for (auto & classifier : classifiers)
+    classifier->saveLast();
 }
 
 bool ReadingMachine::isPredicted(const std::string & columnName) const
@@ -173,34 +150,47 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
 
 void ReadingMachine::trainMode(bool isTrainMode)
 {
-  classifier->getNN()->train(isTrainMode);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->train(isTrainMode);
 }
 
 void ReadingMachine::setDictsState(Dict::State state)
 {
-  classifier->getNN()->setDictsState(state);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->setDictsState(state);
 }
 
-void ReadingMachine::loadLastSaved()
+void ReadingMachine::setCountOcc(bool countOcc)
 {
-  auto lastSavedModel = util::findFilesByExtension(path.parent_path(), fmt::format(ReadingMachine::lastModelFilename, ""));
-  if (!lastSavedModel.empty())
-    torch::load(classifier->getNN(), lastSavedModel[0]);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->setCountOcc(countOcc);
 }
 
-void ReadingMachine::setCountOcc(bool countOcc)
+void ReadingMachine::removeRareDictElements(float rarityThreshold)
 {
-  classifier->getNN()->setCountOcc(countOcc);
+  for (auto & classifier : classifiers)
+    classifier->getNN()->removeRareDictElements(rarityThreshold);
 }
 
-void ReadingMachine::removeRareDictElements(float rarityThreshold)
+void ReadingMachine::resetClassifiers()
 {
-  classifier->getNN()->removeRareDictElements(rarityThreshold);
+  for (unsigned int i = 0; i < classifiers.size(); i++)
+    classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train));
+}
+
+int ReadingMachine::getNbParameters() const
+{
+  int sum = 0;
+
+  for (auto & classifier : classifiers)
+    sum += classifier->getNbParameters();
+
+  return sum;
 }
 
-void ReadingMachine::resetClassifier()
+void ReadingMachine::resetOptimizers()
 {
-  classifier.reset(new Classifier(classifierName, path, classifierDefinition));
-  loadDicts();
+  for (auto & classifier : classifiers)
+    classifier->resetOptimizer();
 }
diff --git a/torch_modules/src/DictHolder.cpp b/torch_modules/src/DictHolder.cpp
index f712112482477f3bb7b868c48b13f534bfd6e407..934115a32f93f24902d9ccb76be35f012032c58b 100644
--- a/torch_modules/src/DictHolder.cpp
+++ b/torch_modules/src/DictHolder.cpp
@@ -18,7 +18,9 @@ void DictHolder::saveDict(std::filesystem::path path)
 
 void DictHolder::loadDict(std::filesystem::path path)
 {
-  dict.reset(new Dict((path / filename()).c_str(), dict->getState()));
+  auto dictPath = path / filename();
+  if (std::filesystem::exists(dictPath))
+    dict.reset(new Dict(dictPath.c_str(), dict->getState()));
 }
 
 Dict & DictHolder::getDict()
diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp
index 26521a8e14d7f4ddf7fde45ddda011934cad8d71..343b786c47502ac8b472b2ab05f5955165bfa7cb 100644
--- a/trainer/src/MacaonTrain.cpp
+++ b/trainer/src/MacaonTrain.cpp
@@ -156,7 +156,7 @@ int MacaonTrain::main()
  try
  {
-    ReadingMachine machine(machinePath.string());
+    ReadingMachine machine(machinePath.string(), true);
 
    BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
    BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
@@ -164,14 +164,6 @@ int MacaonTrain::main()
    Trainer trainer(machine, batchSize, lossFunction);
    Decoder decoder(machine);
 
-    if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty())
-    {
-      machine.loadDicts();
-      machine.getClassifier()->getNN()->registerEmbeddings();
-      machine.loadLastSaved();
-      machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-    }
-
    float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
 
    auto trainInfos = machinePath.parent_path() / "train.info";
@@ -195,13 +187,6 @@ int MacaonTrain::main()
      std::fclose(f);
    }
 
-    auto optimizerCheckpoint = machinePath.parent_path() / "checkpoint.optimizer";
-    if (std::filesystem::exists(trainInfos))
-    {
-      machine.getClassifier()->resetOptimizer();
-      machine.getClassifier()->loadOptimizer(optimizerCheckpoint);
-    }
-
    for (; currentEpoch < nbEpoch; currentEpoch++)
    {
      bool saved = false;
@@ -231,14 +216,12 @@ int MacaonTrain::main()
      {
        if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters))
        {
-          machine.resetClassifier();
+          machine.resetClassifiers();
          machine.trainMode(currentEpoch == 0);
-          machine.getClassifier()->getNN()->registerEmbeddings();
-          machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
-          fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters()));
+          fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
        }
 
-        machine.getClassifier()->resetOptimizer();
+        machine.resetOptimizers();
      }
      if (trainStrategy[currentEpoch].count(Trainer::TrainAction::Save))
      {
@@ -290,8 +273,9 @@ int MacaonTrain::main()
        bestDevScore = devScoreMean;
        machine.saveBest();
      }
+      machine.saveLast();
 
-      machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
+
      if (printAdvancement)
        fmt::print(stderr, "\r{:80}\r", "");
      std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 9c2fad8e5542c07df518fdfa71339846d3095d7f..f161778e464b3f447fa9897ab60899266fc5b68a 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -93,7 +93,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
-  machine.getClassifier()->setState(config.getState());
+  machine.getClassifier(config.getState())->setState(config.getState());
 
  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
@@ -111,14 +111,15 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
-    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
+
+    auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);
 
    std::vector<std::vector<long>> context;
 
    try
    {
-      context = machine.getClassifier()->getNN()->extractContext(config);
+      context = machine.getClassifier(config.getState())->getNN()->extractContext(config);
    } catch(std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
@@ -126,14 +127,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
    Transition * transition = nullptr;
 
-    auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
+    auto goldTransitions = machine.getTransitionSet(config.getState()).getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);
 
    Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
-    int nbClasses = machine.getTransitionSet().size();
+    int nbClasses = machine.getTransitionSet(config.getState()).size();
 
    if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
    {
      auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
-      auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
+      auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze();
 
      float bestScore = std::numeric_limits<float>::min();
      std::vector<int> candidates;
@@ -152,7 +153,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
          candidates.emplace_back(i);
      }
 
-      transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
+      transition = machine.getTransitionSet(config.getState()).getTransition(candidates[std::rand()%candidates.size()]);
    }
    else
    {
@@ -171,7 +172,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
    std::vector<int> goldIndexes;
    for (auto & t : goldTransitions)
-      goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t));
+      goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t));
 
    examplesPerState[config.getState()].addContext(context);
    examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes);
@@ -187,7 +188,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
      break;
 
    config.setState(movement.first);
-    machine.getClassifier()->setState(movement.first);
+    machine.getClassifier(config.getState())->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);
 
    if (config.needsUpdate())
@@ -220,20 +221,20 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
  for (auto & batch : *loader)
  {
-    if (train)
-      machine.getClassifier()->getOptimizer().zero_grad();
-
    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch);
 
-    machine.getClassifier()->setState(state);
+    if (train)
+      machine.getClassifier(state)->getOptimizer().zero_grad();
+
+    machine.getClassifier(state)->setState(state);
 
-    auto prediction = machine.getClassifier()->getNN()(data);
+    auto prediction = machine.getClassifier(state)->getNN()(data);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0);
 
-    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
+    auto loss = machine.getClassifier(state)->getLossMultiplier()*lossFct(prediction, labels);
    float lossAsFloat = 0.0;
    try
    {
@@ -246,7 +247,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
    if (train)
    {
      loss.backward();
-      machine.getClassifier()->getOptimizer().step();
+      machine.getClassifier(state)->getOptimizer().step();
    }
 
    totalNbExamplesProcessed += labels.size(0);
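Editor's note on the pattern introduced by this patch: ReadingMachine now owns several Classifier instances plus a state2classifier map, and every call site resolves the classifier (and its transition set) through the current state, e.g. machine.getClassifier(config.getState()). The sketch below is not code from the patch; it is a minimal, self-contained illustration of that state-to-classifier dispatch, assuming hypothetical ToyClassifier/ToyMachine stand-ins for the real Classifier/ReadingMachine.

// Minimal sketch (not part of the patch) of the state -> classifier lookup.
// ToyClassifier and ToyMachine are hypothetical; the real ReadingMachine stores
// std::vector<std::unique_ptr<Classifier>> classifiers and
// std::map<std::string, int> state2classifier, as declared in ReadingMachine.hpp above.
#include <cstdio>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

struct ToyClassifier
{
  std::string name;
  std::vector<std::string> states; // states this classifier is responsible for
};

class ToyMachine
{
  std::vector<std::unique_ptr<ToyClassifier>> classifiers;
  std::map<std::string, int> state2classifier;

  public :

  void addClassifier(std::unique_ptr<ToyClassifier> classifier)
  {
    classifiers.emplace_back(std::move(classifier));
    // Same indexing scheme as readFromFile() in the patch:
    // every state declared by the classifier maps to its index in the vector.
    for (auto & state : classifiers.back()->states)
      state2classifier[state] = classifiers.size()-1;
  }

  ToyClassifier * getClassifier(const std::string & state)
  {
    // Mirrors classifiers[state2classifier.at(state)].get(), with an explicit
    // error message instead of the std::out_of_range that at() would throw.
    auto it = state2classifier.find(state);
    if (it == state2classifier.end())
      throw std::runtime_error("no classifier registered for state '" + state + "'");
    return classifiers[it->second].get();
  }
};

int main()
{
  ToyMachine machine;
  machine.addClassifier(std::make_unique<ToyClassifier>(ToyClassifier{"tagger", {"tagger", "morpho"}}));
  machine.addClassifier(std::make_unique<ToyClassifier>(ToyClassifier{"parser", {"parser", "segmenter"}}));

  // Call sites always pass the current state, as in
  // machine.getClassifier(config.getState()) throughout the patch.
  std::printf("%s\n", machine.getClassifier("morpho")->name.c_str()); // prints "tagger"
  std::printf("%s\n", machine.getClassifier("parser")->name.c_str()); // prints "parser"
  return 0;
}

The point of the indirection is that several decoding states can share one classifier while other states get their own network, optimizer, and transition set; that is why the patch also moves saving/loading (saveBest, saveLast, loadOptimizer) into Classifier and has ReadingMachine simply loop over its classifiers.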