diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 70c159c046eb0cb0c7306869809f4355f3a988f4..7e721b923cadb307ff86016792b6573a0df6cbf2 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -9,6 +9,7 @@
 #include "DepthLayerTreeEmbeddingModule.hpp"
 #include "StateNameModule.hpp"
 #include "UppercaseRateModule.hpp"
+#include "NumericColumnModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..baa2bc8eb9919ec130e88bf6971d3e8a9e5b518b
--- /dev/null
+++ b/torch_modules/include/NumericColumnModule.hpp
@@ -0,0 +1,31 @@
+#ifndef NUMERICCOLUMNMODULE__H
+#define NUMERICCOLUMNMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+
+// Feeds the numeric values of one column (for focused buffer/stack words)
+// through a recurrent submodule (LSTM or GRU) of input size 1.
+class NumericColumnModuleImpl : public Submodule
+{
+  private :
+
+  int outSize;
+  std::vector<int> focusedBuffer, focusedStack;
+  std::shared_ptr<MyModule> myModule{nullptr};
+  std::string column;
+
+  public :
+
+  NumericColumnModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(NumericColumnModule);
+
+#endif
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 5edf1b42be0cef57c0fbc58c32691cc31a0a83de..b4e23cc1a4effcce03ca6f261bfaacc030512c8c 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -27,6 +27,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
modules.emplace_back(register_module(name, ContextModule(nameH, splited.second))); else if (splited.first == "StateName") modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); + else if (splited.first == "NumericColumn") + modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second))); else if (splited.first == "UppercaseRate") modules.emplace_back(register_module(name, UppercaseRateModule(nameH, splited.second))); else if (splited.first == "Focused") diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp new file mode 100644 index 0000000000000000000000000000000000000000..60ca6e44dd8cdda83c034dcb1c97c58a02e6fde0 --- /dev/null +++ b/torch_modules/src/NumericColumnModule.cpp @@ -0,0 +1,88 @@ +#include "NumericColumnModule.hpp" +#include "NeuralNetwork.hpp" + +NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::string & definition) +{ + setName(name); + std::regex regex("(?:(?:\\s|\\t)*)Column\\{(.*)\\}(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + column = sm.str(1); + + for (auto & index : util::split(sm.str(2), ' ')) + focusedBuffer.emplace_back(std::stoi(index)); + + for (auto & index : util::split(sm.str(3), ' ')) + focusedStack.emplace_back(std::stoi(index)); + + auto subModuleType = sm.str(4); + auto subModuleArguments = util::split(sm.str(5), ' '); + + auto options = MyModule::ModuleOptions(true) + .bidirectional(std::stoi(subModuleArguments[0])) + .num_layers(std::stoi(subModuleArguments[1])) + .dropout(std::stof(subModuleArguments[2])) + .complete(std::stoi(subModuleArguments[3])); + + int outSize = std::stoi(sm.str(6)); + + if (subModuleType == "LSTM") + myModule = register_module("myModule", LSTM(1, outSize, options)); + else if (subModuleType == 
"GRU") + myModule = register_module("myModule", GRU(1, outSize, options)); + else + util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} + })) + util::myThrow(fmt::format("invalid definition '{}'", definition)); +} + +torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input) +{ + auto context = input.narrow(1, firstInputIndex, getInputSize()); + void * dataPtr = context.flatten().data_ptr(); + auto values = torch::from_blob(dataPtr, {(long)(context.size(0)*getInputSize())}, torch::kDouble).clone().to(torch::kFloat).to(NeuralNetworkImpl::device).view({(long)context.size(0), (long)context.size(1), 1}); + return myModule->forward(values); +} + +std::size_t NumericColumnModuleImpl::getOutputSize() +{ + return myModule->getOutputSize(getInputSize()); +} + +std::size_t NumericColumnModuleImpl::getInputSize() +{ + return focusedBuffer.size() + focusedStack.size(); +} + +void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +{ + std::vector<long> focusedIndexes; + + for (int index : focusedBuffer) + focusedIndexes.emplace_back(config.getRelativeWordIndex(index)); + + for (int index : focusedStack) + if (config.hasStack(index)) + focusedIndexes.emplace_back(config.getStack(index)); + else + focusedIndexes.emplace_back(-1); + + for (auto & contextElement : context) + for (auto index : focusedIndexes) + { + double res = 0.0; + if (index >= 0) + res = std::stof(config.getAsFeature(column, index).get()); + + contextElement.emplace_back(0); + std::memcpy(&contextElement.back(), &res, sizeof res); + } +} + +void NumericColumnModuleImpl::registerEmbeddings() +{ +} +