diff --git a/torch_modules/src/SplitTransLSTM.cpp b/torch_modules/src/SplitTransLSTM.cpp
index 283358c8483ee83952704524046e05a787e21217..a83894abf93efda1ed0124fe58057e3ce042be06 100644
--- a/torch_modules/src/SplitTransLSTM.cpp
+++ b/torch_modules/src/SplitTransLSTM.cpp
@@ -24,10 +24,11 @@ std::size_t SplitTransLSTMImpl::getInputSize()
 void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
-  for (int i = 0; i < maxNbTrans; i++)
-    if (i < (int)splitTransitions.size())
-      context.back().emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
-    else
-      context.back().emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  for (auto & contextElement : context)
+    for (int i = 0; i < maxNbTrans; i++)
+      if (i < (int)splitTransitions.size())
+        contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName()));
+      else
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 }
 
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 713cd4fc1b6c8097279ec8e934839be9a0da7e0b..a69b9b9a3a0f926b1117ab02929a43436e67b2f6 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -22,7 +22,6 @@ class Trainer
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
   int batchSize;
-  int nbExamples{0};
 
   private :
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 071abf6689adf93f92f9dae8766794c6dfaced9d..16918e25469d7396c05d1067bc71d0bff7063b2f 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -9,11 +9,10 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
+  machine.trainMode(true);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
 
   trainDataset.reset(new Dataset(dir));
-  nbExamples = trainDataset->size().value();
-
   dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   if (optimizer.get() == nullptr)
@@ -24,6 +23,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
 {
   SubConfig config(goldConfig, goldConfig.getNbLines());
 
+  machine.trainMode(false);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
 
   devDataset.reset(new Dataset(dir));
@@ -43,7 +43,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to
 void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval)
 {
   torch::AutoGradMode useGrad(false);
-  machine.trainMode(false);
   machine.setDictsState(Dict::State::Open);
 
   int maxNbExamplesPerFile = 250000;