From 9108fc2ba64ccd280bb1ca2cc8989084c54c535b Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 6 May 2020 21:06:45 +0200 Subject: [PATCH] Do not compute grad with newer modules --- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/UppercaseRateModule.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 60ca6e4..0ea72c4 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -43,7 +43,7 @@ torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); void * dataPtr = context.flatten().data_ptr(); - auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::kDouble).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1}); + auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false)).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1}); return myModule->forward(values); } diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 70f1bea..67d2b26 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -41,7 +41,7 @@ torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); void * dataPtr = context.flatten().data_ptr(); - auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::kDouble).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1}); + auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false)).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1}); return myModule->forward(values); } -- GitLab