Skip to content
Snippets Groups Projects
Commit cdc9ed54 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected wrong device error when using from_blob

parent 5769657d
No related branches found
No related tags found
No related merge requests found
...@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug) ...@@ -48,7 +48,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
elements[index].config.setAppliableTransitions(appliableTransitions); elements[index].config.setAppliableTransitions(appliableTransitions);
auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction); float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
......
...@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: ...@@ -93,7 +93,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{ {
auto & classifier = *machine.getClassifier(config.getState()); auto & classifier = *machine.getClassifier(config.getState());
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction); entropy = NeuralNetworkImpl::entropy(prediction);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment