diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 123c063aafb174bde6eb543016422b3614c9f3fc..63f7a3b2e00d7150af93fc7906d47f4f72d3df0e 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -2,6 +2,7 @@ #define CONTEXTMODULE__H #include <torch/torch.h> +#include <optional> #include "Submodule.hpp" #include "MyModule.hpp" #include "GRU.hpp" @@ -16,8 +17,7 @@ class ContextModuleImpl : public Submodule std::shared_ptr<MyModule> myModule{nullptr}; std::vector<std::string> columns; std::vector<std::function<std::string(const std::string &)>> functions; - std::vector<int> bufferContext; - std::vector<int> stackContext; + std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets; int inSize; std::filesystem::path w2vFile; diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 75a23b19a10bac91e0911ddde5d8057b134c11c8..a6e1c85974c7b97a60cc6bd0980dc3125163593f 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -4,18 +4,20 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin { setName(name); - std::regex regex("(?:(?:\\s|\\t)*)Buffer\\{(.*)\\}(?:(?:\\s|\\t)*)Stack\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)"); + std::regex regex("(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)"); if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) { try { - for (auto & index : util::split(sm.str(1), ' ')) - bufferContext.emplace_back(std::stoi(index)); - - for (auto & index : util::split(sm.str(2), ' ')) - stackContext.emplace_back(std::stoi(index)); + for (auto & target : util::split(sm.str(1), ' ')) + { + auto splited = util::split(target, '.'); + if (splited.size() != 2 and splited.size() != 3) + util::myThrow(fmt::format("invalid target '{}' expected 'object.index(.childIndex)'", target)); + targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>())); + } - auto funcColumns = util::split(sm.str(3), ' '); + auto funcColumns = util::split(sm.str(2), ' '); columns.clear(); for (auto & funcCol : funcColumns) { @@ -23,8 +25,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin columns.emplace_back(util::split(funcCol, ':').back()); } - auto subModuleType = sm.str(4); - auto subModuleArguments = util::split(sm.str(5), ' '); + auto subModuleType = sm.str(3); + auto subModuleArguments = util::split(sm.str(4), ' '); auto options = MyModule::ModuleOptions(true) .bidirectional(std::stoi(subModuleArguments[0])) @@ -32,8 +34,8 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin .dropout(std::stof(subModuleArguments[2])) .complete(std::stoi(subModuleArguments[3])); - inSize = std::stoi(sm.str(6)); - int outSize = std::stoi(sm.str(7)); + inSize = std::stoi(sm.str(5)); + int outSize = std::stoi(sm.str(6)); if (subModuleType == "LSTM") myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options)); @@ -44,7 +46,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); - w2vFile = sm.str(8); + w2vFile = sm.str(7); if (!w2vFile.empty()) { @@ -60,12 +62,12 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin std::size_t ContextModuleImpl::getOutputSize() { - return myModule->getOutputSize(bufferContext.size()+stackContext.size()); + return myModule->getOutputSize(targets.size()); } std::size_t ContextModuleImpl::getInputSize() { - return columns.size()*(bufferContext.size()+stackContext.size()); + return columns.size()*(targets.size()); } void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) @@ -73,12 +75,24 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c auto & dict = getDict(); std::vector<long> contextIndexes; - for (int index : bufferContext) - contextIndexes.emplace_back(config.getRelativeWordIndex(index)); - - for (int index : stackContext) - if (config.hasStack(index)) - contextIndexes.emplace_back(config.getStack(index)); + for (auto & target : targets) + if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target))) + { + int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target)); + if (!std::get<2>(target)) + contextIndexes.emplace_back(baseIndex); + else + { + int childIndex = *std::get<2>(target); + auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|'); + if (childIndex >= 0 and childIndex < (int)childs.size()) + contextIndexes.emplace_back(std::stoi(childs[childIndex])); + else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0) + contextIndexes.emplace_back(std::stoi(childs[childs.size()+childIndex])); + else + contextIndexes.emplace_back(-1); + } + } else contextIndexes.emplace_back(-1);