From 38bfb9754d165af7f43d0cd7a8df83ce718505bc Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 16 Mar 2021 14:19:01 +0100 Subject: [PATCH] Fixed discrete features values --- common/include/util.hpp | 2 +- common/src/util.cpp | 6 +++++- torch_modules/src/NumericColumnModule.cpp | 6 +++--- torch_modules/src/UppercaseRateModule.cpp | 6 ++---- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/common/include/util.hpp b/common/include/util.hpp index 3b7b6eb..b555e55 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -20,7 +20,7 @@ namespace util //using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>; using String = std::string; -constexpr float float2longScale = 10000; +constexpr float float2longScale = 100000; void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); diff --git a/common/src/util.cpp b/common/src/util.cpp index a30e4a3..e66bf67 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -11,7 +11,11 @@ float util::long2float(long l) long util::float2long(float f) { - return f * util::float2longScale; + float res = f * util::float2longScale; + if (f != 0 and res / f != util::float2longScale) + util::myThrow(fmt::format("Float '{}' is too big to be converted to long", f)); + + return res; } int util::printedLength(std::string_view s) diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 57c33f6..a5001e7 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); - auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1); + auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale; return myModule->forward(values).reshape({input.size(0), -1}); } @@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config {util::myThrow(fmt::format("{} for '{}'", e.what(), value));} } - //TODO : Check if this works - context[firstInputIndex+insertIndex] = res; + context[firstInputIndex+insertIndex] = util::float2long(res); + insertIndex++; } } diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 5f6de21..ff6e0ac 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input) { auto context = input.narrow(1, firstInputIndex, getInputSize()); - auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone(); + auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale; return myModule->forward(values).reshape({input.size(0), -1}); } @@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config res = 1.0*nbUpper/word.size(); } - //TODO : Check if this works - context[firstInputIndex+insertIndex] = res; + context[firstInputIndex+insertIndex] = util::float2long(res); insertIndex++; } - } void UppercaseRateModuleImpl::registerEmbeddings() -- GitLab