From 8d4d9ab5c05f0c4261cb3c61190027a7bd157445 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 18 Jun 2020 22:03:53 +0200 Subject: [PATCH] Make sure nothing is added to the dict mid training --- reading_machine/src/ReadingMachine.cpp | 1 + torch_modules/src/Submodule.cpp | 2 ++ trainer/src/MacaonTrain.cpp | 3 ++- trainer/src/Trainer.cpp | 1 - 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index e1b8169..f5fd3c4 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -11,6 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); + trainMode(false); classifier->getNN()->registerEmbeddings(); classifier->getNN()->to(NeuralNetworkImpl::device); diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index ea63b99..66f6455 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -9,6 +9,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s { if (path.empty()) return; + if (!is_training()) + return; if (!std::filesystem::exists(path)) util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string())); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 21529bf..8e43cd3 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -207,7 +207,7 @@ int MacaonTrain::main() } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { - machine.setDictsState(Dict::State::Open); + machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); if (!computeDevScore) { @@ -220,6 +220,7 @@ int MacaonTrain::main() if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters)) { machine.resetClassifier(); + machine.trainMode(currentEpoch == 0); machine.getClassifier()->getNN()->registerEmbeddings(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 1d363c5..af1ef2e 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -155,7 +155,6 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance torch::AutoGradMode useGrad(train); machine.trainMode(train); - machine.setDictsState(Dict::State::Closed); auto lossFct = torch::nn::CrossEntropyLoss(); -- GitLab