Skip to content
Snippets Groups Projects
Commit 9c793241 authored by Franck Dary's avatar Franck Dary
Browse files

Only transfer tensor to gpu the moment we serve them throught method get of ConfigDataset

parent e45d45e6
No related branches found
No related tags found
No related merge requests found
#include "ConfigDataset.hpp"
#include "NeuralNetwork.hpp"
ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
{
......@@ -24,6 +25,6 @@ torch::optional<size_t> ConfigDataset::size() const
torch::data::Example<> ConfigDataset::get(size_t index)
{
return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
return {data.narrow(0, index*(contextSize+1), contextSize).to(NeuralNetworkImpl::device), data.narrow(0, index*(contextSize+1)+contextSize, 1).to(NeuralNetworkImpl::device)};
}
......@@ -54,14 +54,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
{
context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
for (auto & element : context)
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
gold[0] = goldIndex;
for (auto & element : context)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment