diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 0ea72c4a617acb494a49f721123aba364fc81ede..ac488db7601eb550495aafd5cf658970398e8eb6 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -43,7 +43,7 @@ torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   void * dataPtr = context.flatten().data_ptr();
-  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false)).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1});
+  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1});
   return myModule->forward(values);
 }
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 67d2b26699cb8f759aa36e58736b94a272ac4519..b2ddf61deb11095b47390e5331d4de80c419a32b 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -41,7 +41,7 @@ torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   void * dataPtr = context.flatten().data_ptr();
-  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false)).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1});
+  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1});
   return myModule->forward(values);
 }