From 81fdf35497fbce9dd113c8ad40798f5222f24231 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 4 Jan 2021 08:44:51 +0100
Subject: [PATCH] Added default value for NumericColumnModule

---
 torch_modules/include/NumericColumnModule.hpp | 1 +
 torch_modules/src/NumericColumnModule.cpp     | 7 +++++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
index 26e295a..e3d46ae 100644
--- a/torch_modules/include/NumericColumnModule.hpp
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -13,6 +13,7 @@ class NumericColumnModuleImpl : public Submodule
   private :
 
   int outSize;
+  float defaultValue;
   std::vector<int> focusedBuffer, focusedStack;
   std::shared_ptr<MyModule> myModule{nullptr};
   std::string column;
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index b535ded..8fb0ab3 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -4,7 +4,7 @@
 NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition)
 {
   setName(name);
-  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)DefaultValue\\{(.*)\\}(?:(?:\\s|\\t)*)");
   if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
         {
           try
@@ -28,6 +28,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
 
             int outSize = std::stoi(sm.str(6));
 
+            defaultValue = std::stoi(sm.str(7));
+
             if (subModuleType == "LSTM")
               myModule = register_module("myModule", LSTM(1, outSize, options));
             else if (subModuleType == "GRU")
@@ -76,7 +78,8 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
     {
       double res = 0.0;
       if (index >= 0)
-        res = std::stof(config.getAsFeature(column, index).get());
+        try {res = std::stof(config.getAsFeature(column, index).get());}
+        catch (std::exception &) {res = defaultValue;}
 
       contextElement.emplace_back(0);
       std::memcpy(&contextElement.back(), &res, sizeof res);
-- 
GitLab