From 9c793241516148c9310dfa60b0acd75979eb7323 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 2 Apr 2020 10:18:40 +0200
Subject: [PATCH] Only transfer tensor to gpu the moment we serve them throught
 method get of ConfigDataset

---
 torch_modules/src/ConfigDataset.cpp | 3 ++-
 trainer/src/Trainer.cpp             | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index 439cbcc..35f1942 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,4 +1,5 @@
 #include "ConfigDataset.hpp"
+#include "NeuralNetwork.hpp"
 
 ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
 {
@@ -24,6 +25,6 @@ torch::optional<size_t> ConfigDataset::size() const
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
+  return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)};
 }
 
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 6ebf18e..ccbad4e 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -54,14 +54,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
     {
       context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
       for (auto & element : context)
-        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
+        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
     }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
     gold[0] = goldIndex;
 
     for (auto & element : context)
-- 
GitLab