Commit f799c58f authored by Franck Dary

Fixed modules using values instead of embeddings

parent e00aa24d
@@ -42,8 +42,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
 torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
-  void * dataPtr = context.flatten().data_ptr();
-  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1});
+  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
   return myModule->forward(values);
 }
@@ -40,8 +40,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
 torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
-  void * dataPtr = context.flatten().data_ptr();
-  auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).clone().to(torch::kFloat).view({(long)context.size(0), (long)context.size(1), 1});
+  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
   return myModule->forward(values);
 }
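Both hunks replace the same two-step pattern (flatten the narrowed context, keep its raw data_ptr, rebuild a 1-D tensor, then reshape it) with a single torch::from_blob call that passes the view's own sizes and strides. A minimal standalone sketch of that pattern, assuming only libtorch and using made-up dimensions in place of the modules' firstInputIndex and getInputSize(), could look like this; it is illustrative only and is not part of the commit:

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Fake batch: 2 rows, 5 columns of raw double values (sizes are made up,
  // standing in for the module's "input" tensor).
  auto input = torch::arange(10, torch::kDouble).view({2, 5});

  // What the modules do: select a window of columns. narrow() returns a
  // non-contiguous view that still shares input's storage.
  auto context = input.narrow(1, 1, 3);

  // Pattern from the commit: wrap the view's memory directly, handing
  // from_blob its sizes and strides so the layout is respected, then convert
  // to float, add a trailing feature dimension and clone so the result owns
  // its own contiguous storage. (The real code also pins the options to
  // NeuralNetworkImpl::device, omitted here to stay self-contained.)
  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(),
                                 torch::TensorOptions(torch::kDouble).requires_grad(false))
                    .to(torch::kFloat)
                    .unsqueeze(-1)
                    .clone();

  std::cout << values.sizes() << '\n'; // prints [2, 3, 1]
}

Compared with the removed lines, this avoids stashing the data_ptr() of the temporary returned by context.flatten() (a pointer that can dangle once that temporary is destroyed) and no longer hand-computes the flattened length before reshaping back to (batch, window, 1).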