diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 929d8620f530127c313378173d0c423de6e44151..d3157426b9c38268897a42a377e90fbf03aa6632 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -92,8 +92,9 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { - auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); + auto & classifier = *machine.getClassifier(config.getState()); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); + auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates;