Skip to content
Snippets Groups Projects
Commit 38bfb975 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed discrete features values

parent 3f77dfcb
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ namespace util ...@@ -20,7 +20,7 @@ namespace util
//using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>; //using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>;
using String = std::string; using String = std::string;
constexpr float float2longScale = 10000; constexpr float float2longScale = 100000;
void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
......
...@@ -11,7 +11,11 @@ float util::long2float(long l) ...@@ -11,7 +11,11 @@ float util::long2float(long l)
long util::float2long(float f) long util::float2long(float f)
{ {
return f * util::float2longScale; float res = f * util::float2longScale;
if (f != 0 and res / f != util::float2longScale)
util::myThrow(fmt::format("Float '{}' is too big to be converted to long", f));
return res;
} }
int util::printedLength(std::string_view s) int util::printedLength(std::string_view s)
......
...@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st ...@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input) torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{ {
auto context = input.narrow(1, firstInputIndex, getInputSize()); auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1); auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1}); return myModule->forward(values).reshape({input.size(0), -1});
} }
...@@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config ...@@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config
{util::myThrow(fmt::format("{} for '{}'", e.what(), value));} {util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
} }
//TODO : Check if this works context[firstInputIndex+insertIndex] = util::float2long(res);
context[firstInputIndex+insertIndex] = res; insertIndex++;
} }
} }
......
...@@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st ...@@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input) torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
{ {
auto context = input.narrow(1, firstInputIndex, getInputSize()); auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone(); auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1}); return myModule->forward(values).reshape({input.size(0), -1});
} }
...@@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config ...@@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config
res = 1.0*nbUpper/word.size(); res = 1.0*nbUpper/word.size();
} }
//TODO : Check if this works context[firstInputIndex+insertIndex] = util::float2long(res);
context[firstInputIndex+insertIndex] = res;
insertIndex++; insertIndex++;
} }
} }
void UppercaseRateModuleImpl::registerEmbeddings() void UppercaseRateModuleImpl::registerEmbeddings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment