From 38bfb9754d165af7f43d0cd7a8df83ce718505bc Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 16 Mar 2021 14:19:01 +0100
Subject: [PATCH] Fixed discrete features values

---
 common/include/util.hpp                   | 2 +-
 common/src/util.cpp                       | 6 +++++-
 torch_modules/src/NumericColumnModule.cpp | 6 +++---
 torch_modules/src/UppercaseRateModule.cpp | 6 ++----
 4 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/common/include/util.hpp b/common/include/util.hpp
index 3b7b6eb..b555e55 100644
--- a/common/include/util.hpp
+++ b/common/include/util.hpp
@@ -20,7 +20,7 @@ namespace util
 //using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>;
 using String = std::string;
 
-constexpr float float2longScale = 10000;
+constexpr float float2longScale = 100000;
 
 void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
 void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
diff --git a/common/src/util.cpp b/common/src/util.cpp
index a30e4a3..e66bf67 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -11,7 +11,11 @@ float util::long2float(long l)
 
 long util::float2long(float f)
 {
-  return f * util::float2longScale;
+  float res = f * util::float2longScale;
+  if (f != 0 and res / f != util::float2longScale)
+    util::myThrow(fmt::format("Float '{}' is too big to be converted to long", f));
+  
+  return res;
 }
 
 int util::printedLength(std::string_view s)
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 57c33f6..a5001e7 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
 torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
-  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
+  auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
   return myModule->forward(values).reshape({input.size(0), -1});
 }
 
@@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config
         {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
     }
 
-    //TODO : Check if this works
-    context[firstInputIndex+insertIndex] = res;
+    context[firstInputIndex+insertIndex] = util::float2long(res);
+    insertIndex++;
   }
 }
 
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 5f6de21..ff6e0ac 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
 torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
-  auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
+  auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
   return myModule->forward(values).reshape({input.size(0), -1});
 }
 
@@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config
         res = 1.0*nbUpper/word.size();
     }
 
-    //TODO : Check if this works
-    context[firstInputIndex+insertIndex] = res;
+    context[firstInputIndex+insertIndex] = util::float2long(res);
     insertIndex++;
   }
-
 }
 
 void UppercaseRateModuleImpl::registerEmbeddings()
-- 
GitLab