From 81fdf35497fbce9dd113c8ad40798f5222f24231 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 4 Jan 2021 08:44:51 +0100 Subject: [PATCH] Added default value for NumericColumnModule --- torch_modules/include/NumericColumnModule.hpp | 1 + torch_modules/src/NumericColumnModule.cpp | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index 26e295a..e3d46ae 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -13,6 +13,7 @@ class NumericColumnModuleImpl : public Submodule private : int outSize; + float defaultValue; std::vector<int> focusedBuffer, focusedStack; std::shared_ptr<MyModule> myModule{nullptr}; std::string column; diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index b535ded..8fb0ab3 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -4,7 +4,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition) { setName(name); - std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)DefaultValue\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try @@ -28,6 +28,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st int outSize = std::stoi(sm.str(6)); + defaultValue = std::stoi(sm.str(7)); + if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(1, outSize, options)); else if (subModuleType == "GRU") @@ -76,7 +78,8 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont { double res = 0.0; if (index >= 0) - res = std::stof(config.getAsFeature(column, index).get()); + try {res = std::stof(config.getAsFeature(column, index).get());} + catch (std::exception &) {res = defaultValue;} contextElement.emplace_back(0); std::memcpy(&contextElement.back(), &res, sizeof res); -- GitLab