diff --git a/reading_machine/src/ReadingMachine.cpp b/reading_machine/src/ReadingMachine.cpp index e1b81693ad0d644b2832270e4776fb08974b9a6c..f5fd3c4c0c063d0b7b5e29744e9d01a4f29d6d20 100644 --- a/reading_machine/src/ReadingMachine.cpp +++ b/reading_machine/src/ReadingMachine.cpp @@ -11,6 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file readFromFile(path); loadDicts(); + trainMode(false); classifier->getNN()->registerEmbeddings(); classifier->getNN()->to(NeuralNetworkImpl::device); diff --git a/torch_modules/src/Submodule.cpp b/torch_modules/src/Submodule.cpp index ea63b99f952ba705bf7a4d9dd6ae7475568d4704..66f645547d876eab18c6e18971404013f67fac31 100644 --- a/torch_modules/src/Submodule.cpp +++ b/torch_modules/src/Submodule.cpp @@ -9,6 +9,8 @@ void Submodule::loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, s { if (path.empty()) return; + if (!is_training()) + return; if (!std::filesystem::exists(path)) util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string())); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 21529bf3b4e29b86a851f69dad48ae225d12d8eb..8e43cd34a07dd683c92153526a50e92db466e0e1 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -207,7 +207,7 @@ int MacaonTrain::main() } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { - machine.setDictsState(Dict::State::Open); + machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); if (!computeDevScore) { @@ -220,6 +220,7 @@ int MacaonTrain::main() if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters)) { machine.resetClassifier(); + machine.trainMode(currentEpoch == 0); machine.getClassifier()->getNN()->registerEmbeddings(); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getClassifier()->getNbParameters())); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 1d363c5c66bdee84237887c2c4f255142790eda0..af1ef2e90dfc3b0e60d6a7bab116f5b46ef5f4d7 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -155,7 +155,6 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance torch::AutoGradMode useGrad(train); machine.trainMode(train); - machine.setDictsState(Dict::State::Closed); auto lossFct = torch::nn::CrossEntropyLoss();