From f799c58fe04a1f5080209ab4237190ce51b43b94 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 4 Jun 2020 16:02:52 +0200 Subject: [PATCH] Fixed modules using values instead of embeddings --- torch_modules/src/NumericColumnModule.cpp | 3 +-- torch_modules/src/UppercaseRateModule.cpp | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 45ebb1a..5f8c8d4 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -42,8 +42,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); - void * dataPtr = context.flatten().data_ptr(); - auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1}); + auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone(); return myModule->forward(values); } diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 2118745..7f92e05 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -40,8 +40,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); - void * dataPtr = context.flatten().data_ptr(); - auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1}); + auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone(); return myModule->forward(values); } -- GitLab