From cdc9ed544399c72b0ee5f84ee4817fd7de415a4f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 2 Mar 2021 17:50:49 +0100
Subject: [PATCH] Corrected wrong device error when using from_blob

---
 decoder/src/Beam.cpp    | 2 +-
 trainer/src/Trainer.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp
index 1b0aafa..e39593c 100644
--- a/decoder/src/Beam.cpp
+++ b/decoder/src/Beam.cpp
@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
     elements[index].config.setAppliableTransitions(appliableTransitions);
 
     auto context = classifier.getNN()->extractContext(elements[index].config).back();
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
 
     auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
     float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index a85fff0..7f1aaec 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
       if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
       {
         auto & classifier = *machine.getClassifier(config.getState());
-        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+        auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
         auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
         entropy  = NeuralNetworkImpl::entropy(prediction);
     
-- 
GitLab