From 9df2da0fa5ec7653f62fbc9dd37145d661481cb0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 1 Mar 2021 10:36:54 +0100
Subject: [PATCH] Added case for regression in extractExamples

---
 trainer/src/Trainer.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 929d862..d315742 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -92,8 +92,9 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
         
       if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
       {
-        auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
-        auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
+        auto & classifier = *machine.getClassifier(config.getState());
+        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
+        auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
         std::vector<int> candidates;
-- 
GitLab