Skip to content
Snippets Groups Projects
Commit 38bfb975 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed discrete features values

parent 3f77dfcb
No related branches found
No related tags found
No related merge requests found
......@@ -20,7 +20,7 @@ namespace util
//using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>;
using String = std::string;
constexpr float float2longScale = 10000;
constexpr float float2longScale = 100000;
void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
......
......@@ -11,7 +11,11 @@ float util::long2float(long l)
long util::float2long(float f)
{
return f * util::float2longScale;
float res = f * util::float2longScale;
if (f != 0 and res / f != util::float2longScale)
util::myThrow(fmt::format("Float '{}' is too big to be converted to long", f));
return res;
}
int util::printedLength(std::string_view s)
......
......@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1});
}
......@@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config
{util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
}
//TODO : Check if this works
context[firstInputIndex+insertIndex] = res;
context[firstInputIndex+insertIndex] = util::float2long(res);
insertIndex++;
}
}
......
......@@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1});
}
......@@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config
res = 1.0*nbUpper/word.size();
}
//TODO : Check if this works
context[firstInputIndex+insertIndex] = res;
context[firstInputIndex+insertIndex] = util::float2long(res);
insertIndex++;
}
}
void UppercaseRateModuleImpl::registerEmbeddings()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment