From 219be1d73a9f97d2feb02976d980857c71259eff Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 15 Apr 2020 23:05:01 +0200 Subject: [PATCH] Fixed unknownValueThreshold usage --- torch_modules/src/SplitTransLSTM.cpp | 11 ++++++----- trainer/include/Trainer.hpp | 1 - trainer/src/Trainer.cpp | 5 ++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp index 283358c..a83894a 100644 --- a/torch_modules/src/SplitTransLSTM.cpp +++ b/torch_modules/src/SplitTransLSTM.cpp @@ -24,10 +24,11 @@ std::size_t SplitTransLSTMImpl::getInputSize() void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const { auto & splitTransitions = config.getAppliableSplitTransitions(); - for (int i = 0; i < maxNbTrans; i++) - if (i < (int)splitTransitions.size()) - context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); - else - context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); + for (auto & contextElement : context) + for (int i = 0; i < maxNbTrans; i++) + if (i < (int)splitTransitions.size()) + contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName())); + else + contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr)); } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 713cd4f..a69b9b9 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -22,7 +22,6 @@ class Trainer std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; int batchSize; - int nbExamples{0}; private : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 071abf6..16918e2 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -9,11 +9,10 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem { SubConfig config(goldConfig, goldConfig.getNbLines()); + machine.trainMode(true); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); trainDataset.reset(new Dataset(dir)); - nbExamples = trainDataset->size().value(); - dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); if (optimizer.get() == nullptr) @@ -24,6 +23,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys { SubConfig config(goldConfig, goldConfig.getNbLines()); + machine.trainMode(false); extractExamples(config, debug, dir, epoch, dynamicOracleInterval); devDataset.reset(new Dataset(dir)); @@ -43,7 +43,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { torch::AutoGradMode useGrad(false); - machine.trainMode(false); machine.setDictsState(Dict::State::Open); int maxNbExamplesPerFile = 250000; -- GitLab