From 95131d6d8186067630bb3457395ad04071964833 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 10 Jun 2020 13:36:31 +0200
Subject: [PATCH] Added DistanceModule

---
 reading_machine/include/Config.hpp       |   1 +
 reading_machine/src/Config.cpp           |  26 +++++
 torch_modules/include/DistanceModule.hpp |  34 +++++++
 torch_modules/include/ModularNetwork.hpp |   1 +
 torch_modules/src/DistanceModule.cpp     | 115 +++++++++++++++++++++++
 torch_modules/src/ModularNetwork.cpp     |   2 +
 6 files changed, 179 insertions(+)
 create mode 100644 torch_modules/include/DistanceModule.hpp
 create mode 100644 torch_modules/src/DistanceModule.cpp

diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index faa594e..71fb1ee 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -98,6 +98,7 @@ class Config
 
   std::size_t & getStackRef(int relativeIndex);
   long getRelativeWordIndex(int relativeIndex) const;
+  long getRelativeDistance(int fromIndex, int toIndex) const;
 
   public :
 
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 4e320b9..c967a4f 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -675,6 +675,32 @@ long Config::getRelativeWordIndex(int relativeIndex) const
 
   return -1;
 }
 
+long Config::getRelativeDistance(int fromIndex, int toIndex) const
+{
+  if (toIndex < fromIndex)
+  {
+    for (int index = fromIndex, counter = 0; has(0,index,0); --index)
+      if (!isCommentPredicted(index))
+      {
+        if (index == toIndex)
+          return counter;
+        --counter;
+      }
+  }
+  else
+  {
+    for (int index = fromIndex, counter = 0; has(0,index,0); ++index)
+      if (!isCommentPredicted(index))
+      {
+        if (index == toIndex)
+          return counter;
+        ++counter;
+      }
+  }
+
+  return 0;
+}
+
 long Config::getRelativeWordIndex(Object object, int relativeIndex) const
 {
   if (object == Object::Buffer)
diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp
new file mode 100644
index 0000000..b6e22d8
--- /dev/null
+++ b/torch_modules/include/DistanceModule.hpp
@@ -0,0 +1,34 @@
+#ifndef DISTANCEMODULE__H
+#define DISTANCEMODULE__H
+
+#include <torch/torch.h>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "LSTM.hpp"
+#include "GRU.hpp"
+#include "Concat.hpp"
+
+class DistanceModuleImpl : public Submodule
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
+  std::vector<int> fromBuffer, fromStack;
+  std::vector<int> toBuffer, toStack;
+  int threshold;
+  int inSize;
+
+  public :
+
+  DistanceModuleImpl(std::string name, const std::string & definition);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings(std::filesystem::path pretrained) override;
+};
+TORCH_MODULE(DistanceModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index a6a6c3e..7e98302 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -12,6 +12,7 @@
 #include "UppercaseRateModule.hpp"
 #include "NumericColumnModule.hpp"
 #include "HistoryModule.hpp"
+#include "DistanceModule.hpp"
 #include "MLP.hpp"
 
 class ModularNetworkImpl : public NeuralNetworkImpl
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
new file mode 100644
index 0000000..50deea0
--- /dev/null
+++ b/torch_modules/src/DistanceModule.cpp
@@ -0,0 +1,115 @@
+#include "DistanceModule.hpp"
+
+DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & definition)
+{
+  setName(name);
+  std::regex regex("(?:(?:\\s|\\t)*)FromBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)FromStack\\{(.*)\\}(?:(?:\\s|\\t)*)ToBuffer\\{(.*)\\}(?:(?:\\s|\\t)*)ToStack\\{(.*)\\}(?:(?:\\s|\\t)*)Threshold\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+    {
+      try
+      {
+        for (auto & index : util::split(sm.str(1), ' '))
+          fromBuffer.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(2), ' '))
+          fromStack.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(3), ' '))
+          toBuffer.emplace_back(std::stoi(index));
+
+        for (auto & index : util::split(sm.str(4), ' '))
+          toStack.emplace_back(std::stoi(index));
+
+        threshold = std::stoi(sm.str(5));
+
+        auto subModuleType = sm.str(6);
+        auto subModuleArguments = util::split(sm.str(7), ' ');
+
+        auto options = MyModule::ModuleOptions(true)
+          .bidirectional(std::stoi(subModuleArguments[0]))
+          .num_layers(std::stoi(subModuleArguments[1]))
+          .dropout(std::stof(subModuleArguments[2]))
+          .complete(std::stoi(subModuleArguments[3]));
+
+        inSize = std::stoi(sm.str(8));
+        int outSize = std::stoi(sm.str(9));
+
+        if (subModuleType == "LSTM")
+          myModule = register_module("myModule", LSTM(inSize, outSize, options));
+        else if (subModuleType == "GRU")
+          myModule = register_module("myModule", GRU(inSize, outSize, options));
+        else if (subModuleType == "Concat")
+          myModule = register_module("myModule", Concat(inSize));
+        else
+          util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+      } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+    }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
+{
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+}
+
+std::size_t DistanceModuleImpl::getOutputSize()
+{
+  return myModule->getOutputSize(getInputSize());
+}
+
+std::size_t DistanceModuleImpl::getInputSize()
+{
+  return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size());
+}
+
+void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & dict = getDict();
+  std::vector<long> fromIndexes, toIndexes;
+
+  for (int index : fromBuffer)
+    fromIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : fromStack)
+    if (config.hasStack(index))
+      fromIndexes.emplace_back(config.getStack(index));
+    else
+      fromIndexes.emplace_back(-1);
+
+  for (int index : toBuffer)
+    toIndexes.emplace_back(config.getRelativeWordIndex(index));
+
+  for (int index : toStack)
+    if (config.hasStack(index))
+      toIndexes.emplace_back(config.getStack(index));
+    else
+      toIndexes.emplace_back(-1);
+
+  for (auto & contextElement : context)
+  {
+    for (auto from : fromIndexes)
+      for (auto to : toIndexes)
+      {
+        if (from == -1 or to == -1)
+        {
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+          continue;
+        }
+
+        long dist = std::abs(config.getRelativeDistance(from, to));
+
+        if (dist <= threshold)
+          contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}", dist)));
+        else
+          contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr));
+      }
+  }
+}
+
+void DistanceModuleImpl::registerEmbeddings(std::filesystem::path path)
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  loadPretrainedW2vEmbeddings(wordEmbeddings, path);
+}
+
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index f9707a1..22cdb3a 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -45,6 +45,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
       modules.emplace_back(register_module(name, SplitTransModule(nameH, Config::maxNbAppliableSplitTransitions, splited.second)));
     else if (splited.first == "AppliableTrans")
       modules.emplace_back(register_module(name, AppliableTransModule(nameH, maxNbOutputs)));
+    else if (splited.first == "Distance")
+      modules.emplace_back(register_module(name, DistanceModule(nameH, splited.second)));
     else if (splited.first == "DepthLayerTree")
       modules.emplace_back(register_module(name, DepthLayerTreeEmbeddingModule(nameH, splited.second)));
     else if (splited.first == "MLP")
-- 
GitLab
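
Note on usage: the constructor's regex implies module definitions of the form
FromBuffer{...} FromStack{...} ToBuffer{...} ToStack{...} Threshold{...}
<SubModule>{...} In{...} Out{...}. A hypothetical machine-file line wiring in
this module could look like the following (the "Distance :" prefix matches the
"Distance" branch added to ModularNetwork.cpp, but the exact line syntax and
all field values here are illustrative assumptions, not taken from this patch;
the four LSTM arguments map to bidirectional, num_layers, dropout and complete,
per the ModuleOptions calls in the constructor):

  Distance : FromBuffer{-3 -2 -1 0 1 2} FromStack{0 1 2} ToBuffer{0} ToStack{0 1 2} Threshold{10} LSTM{1 1 0.3 1} In{64} Out{128}

With such a definition, each of the 9x4 (from, to) pairs contributes one
embedded symbol per configuration: the absolute word distance as a string when
it is at most Threshold, Dict::unknownValueStr when it exceeds it, and
Dict::nullValueStr when either side refers to a missing stack element.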
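
The distance computation itself lives in Config::getRelativeDistance: it walks
from fromIndex toward toIndex, counts only lines that are not predicted
comments, and returns a signed count (negative when toIndex lies to the left).
A minimal standalone sketch of that logic, with the Config machinery replaced
by a plain vector of comment flags (an illustrative assumption; the real method
uses has() and isCommentPredicted()):

  #include <cstdio>
  #include <vector>

  // isComment[i] == true means line i is a predicted comment line.
  long relativeDistance(const std::vector<bool> & isComment, int fromIndex, int toIndex)
  {
    int step = toIndex < fromIndex ? -1 : 1;       // walk direction
    long counter = 0;
    for (int index = fromIndex; index >= 0 and index < (int)isComment.size(); index += step)
      if (!isComment[index])                       // comment lines do not count
      {
        if (index == toIndex)
          return counter;
        counter += step;                           // signed word distance
      }
    return 0;                                      // toIndex unreachable (or a comment line)
  }

  int main()
  {
    // indexes:            0      1      2      3      4
    // content:          word   <#c>   word   <#c>   word     (<#c> = comment line)
    std::vector<bool> isComment{false, true, false, true, false};
    std::printf("%ld\n", relativeDistance(isComment, 0, 4)); //  2 : two words to the right
    std::printf("%ld\n", relativeDistance(isComment, 4, 0)); // -2 : two words to the left
    std::printf("%ld\n", relativeDistance(isComment, 2, 2)); //  0 : same word
  }

One consequence worth noting: the fallback return of 0 (target never reached,
or the target is itself a comment line) is indistinguishable from a genuine
zero distance, so addToContext will then emit the dictionary entry for "0"
rather than Dict::unknownValueStr.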