diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index d6101221408b3795260a23e08a1e3c81de3b2deb..03b7f880742df3dd737f555fd96ec1490a6c5a7c 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -9,7 +9,9 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem { SubConfig config(goldConfig, goldConfig.getNbLines()); - machine.trainMode(true); + machine.trainMode(false); + machine.setDictsState(Dict::State::Open); + extractExamples(config, debug, dir, epoch, dynamicOracleInterval); trainDataset.reset(new Dataset(dir)); @@ -21,6 +23,8 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); + machine.setDictsState(Dict::State::Closed); + extractExamples(config, debug, dir, epoch, dynamicOracleInterval); devDataset.reset(new Dataset(dir)); @@ -40,7 +44,6 @@ void Trainer::saveExamples(std::vector<torch::Tensor> & contexts, std::vector<to void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval) { torch::AutoGradMode useGrad(false); - machine.setDictsState(Dict::State::Open); int maxNbExamplesPerFile = 250000; int currentExampleIndex = 0; @@ -92,7 +95,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p Transition * transition = nullptr; - if (dynamicOracle and config.getState() != "tokenizer") + if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();