From cdc9ed544399c72b0ee5f84ee4817fd7de415a4f Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 2 Mar 2021 17:50:49 +0100 Subject: [PATCH] Corrected wrong device error when using from_blob --- decoder/src/Beam.cpp | 2 +- trainer/src/Trainer.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 1b0aafa..e39593c 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug) elements[index].config.setAppliableTransitions(appliableTransitions); auto context = classifier.getNN()->extractContext(elements[index].config).back(); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index a85fff0..7f1aaec 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto & classifier = *machine.getClassifier(config.getState()); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); -- GitLab