diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index 26e295a50fa88248817385e293a17c5b6c4864fb..e3d46aece017226572fbc469e53f85dfbaddc518 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -13,6 +13,7 @@ class NumericColumnModuleImpl : public Submodule private : int outSize; + float defaultValue; std::vector<int> focusedBuffer, focusedStack; std::shared_ptr<MyModule> myModule{nullptr}; std::string column; diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index b535ded6b006b9249ddc1574c3aae816bb36a050..8fb0ab33a0cf074810f9278750d737f97a042fad 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -4,7 +4,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition) { setName(name); - std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)DefaultValue\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try @@ -28,6 +28,8 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st int outSize = std::stoi(sm.str(6)); + defaultValue = std::stoi(sm.str(7)); + if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(1, outSize, options)); else if (subModuleType == "GRU") @@ -76,7 +78,8 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont { double res = 0.0; if (index >= 0) - res = std::stof(config.getAsFeature(column, index).get()); + try {res = std::stof(config.getAsFeature(column, index).get());} + catch (std::exception &) {res = defaultValue;} contextElement.emplace_back(0); std::memcpy(&contextElement.back(), &res, sizeof res);