Commit 38bfb975 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed discrete features values

parent 3f77dfcb
......@@ -20,7 +20,7 @@ namespace util
//using String = boost::flyweights::flyweight<std::string,boost::flyweights::no_tracking>;
using String = std::string;
constexpr float float2longScale = 10000;
constexpr float float2longScale = 100000;
void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
......
......@@ -11,7 +11,11 @@ float util::long2float(long l)
long util::float2long(float f)
{
return f * util::float2longScale;
float res = f * util::float2longScale;
if (f != 0 and res / f != util::float2longScale)
util::myThrow(fmt::format("Float '{}' is too big to be converted to long", f));
return res;
}
int util::printedLength(std::string_view s)
......
......@@ -46,7 +46,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1);
auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1});
}
......@@ -85,8 +85,8 @@ void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config
{util::myThrow(fmt::format("{} for '{}'", e.what(), value));}
}
//TODO : Check if this works
context[firstInputIndex+insertIndex] = res;
context[firstInputIndex+insertIndex] = util::float2long(res);
insertIndex++;
}
}
......
......@@ -42,7 +42,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
{
auto context = input.narrow(1, firstInputIndex, getInputSize());
auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
auto values = context.to(torch::kFloat).unsqueeze(-1) / util::float2longScale;
return myModule->forward(values).reshape({input.size(0), -1});
}
......@@ -84,11 +84,9 @@ void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config
res = 1.0*nbUpper/word.size();
}
//TODO : Check if this works
context[firstInputIndex+insertIndex] = res;
context[firstInputIndex+insertIndex] = util::float2long(res);
insertIndex++;
}
}
void UppercaseRateModuleImpl::registerEmbeddings()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment