From 9c793241516148c9310dfa60b0acd75979eb7323 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 2 Apr 2020 10:18:40 +0200 Subject: [PATCH] Only transfer tensor to gpu the moment we serve them throught method get of ConfigDataset --- torch_modules/src/ConfigDataset.cpp | 3 ++- trainer/src/Trainer.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index 439cbcc..35f1942 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -1,4 +1,5 @@ #include "ConfigDataset.hpp" +#include "NeuralNetwork.hpp" ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes) { @@ -24,6 +25,6 @@ torch::optional<size_t> ConfigDataset::size() const torch::data::Example<> ConfigDataset::get(size_t index) { - return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)}; + return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)}; } diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 6ebf18e..ccbad4e 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -54,14 +54,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch: { context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); for (auto & element : context) - contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device)); + contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone()); } catch(std::exception & e) { util::myThrow(fmt::format("Failed to extract context : {}", e.what())); } int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); gold[0] = goldIndex; for (auto & element : context) -- GitLab