From 6d627fa107b6fc8407934b46539c00a921113228 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 16 Apr 2021 11:52:57 +0200 Subject: [PATCH] Added module HistoryMine --- reading_machine/include/Config.hpp | 3 + reading_machine/src/Config.cpp | 16 +++++ torch_modules/include/HistoryMineModule.hpp | 34 ++++++++++ torch_modules/include/ModularNetwork.hpp | 1 + torch_modules/src/HistoryMineModule.cpp | 74 +++++++++++++++++++++ torch_modules/src/ModularNetwork.cpp | 2 + 6 files changed, 130 insertions(+) create mode 100644 torch_modules/include/HistoryMineModule.hpp create mode 100644 torch_modules/src/HistoryMineModule.cpp diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 49eff0d..ed0a29d 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -54,6 +54,7 @@ class Config std::size_t currentSentenceStartRawInput{0}; util::String state{"NONE"}; boost::circular_buffer<util::String> history{10}; + std::map<std::string, boost::circular_buffer<util::String>> stateHistory; boost::circular_buffer<std::size_t> stack{50}; float chosenActionScore{0.0}; std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName}; @@ -142,9 +143,11 @@ class Config long getRelativeWordIndex(Object object, int relativeIndex) const; bool hasRelativeWordIndex(Object object, int relativeIndex) const; const util::String & getHistory(int relativeIndex) const; + const util::String & getHistoryState(int relativeIndex) const; std::size_t getStack(int relativeIndex) const; std::size_t getStackSize() const; bool hasHistory(int relativeIndex) const; + bool hasHistoryState(int relativeIndex) const; bool hasStack(int relativeIndex) const; util::String getState() const; void setState(const std::string state); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index bed8601..69ae400 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -18,6 +18,7 @@ Config::Config(const Config & other) this->characterIndex = other.characterIndex; this->state = other.state; this->history = other.history; + this->stateHistory = other.stateHistory; this->stack = other.stack; this->extraColumns = other.extraColumns; this->mcd = other.mcd; @@ -361,7 +362,10 @@ Config::ConstValueIterator Config::getConstIterator(int colIndex, int lineIndex, void Config::addToHistory(const std::string & transition) { + if (stateHistory.count(state) == 0) + stateHistory.insert({state, boost::circular_buffer<util::String>(10)}); history.push_back(util::String(transition)); + stateHistory.at(state).push_back(util::String(transition)); } void Config::addToStack(std::size_t index) @@ -491,6 +495,11 @@ const util::String & Config::getHistory(int relativeIndex) const return history[history.size()-1-relativeIndex]; } +const util::String & Config::getHistoryState(int relativeIndex) const +{ + return stateHistory.at(state)[stateHistory.at(state).size()-1-relativeIndex]; +} + std::size_t Config::getStack(int relativeIndex) const { if (relativeIndex == -1) @@ -508,6 +517,13 @@ bool Config::hasHistory(int relativeIndex) const return relativeIndex >= 0 && relativeIndex < (int)history.size(); } +bool Config::hasHistoryState(int relativeIndex) const +{ + if (stateHistory.count(state) == 0) + return false; + return relativeIndex >= 0 && relativeIndex < (int)stateHistory.at(state).size(); +} + bool Config::hasStack(int relativeIndex) const { if (relativeIndex == -1) diff --git a/torch_modules/include/HistoryMineModule.hpp b/torch_modules/include/HistoryMineModule.hpp new file mode 100644 index 0000000..7f6afd6 --- /dev/null +++ b/torch_modules/include/HistoryMineModule.hpp @@ -0,0 +1,34 @@ +#ifndef HISTORYMINEMODULE__H +#define HISTORYMINEMODULE__H + +#include <torch/torch.h> +#include "Submodule.hpp" +#include "MyModule.hpp" +#include "LSTM.hpp" +#include "GRU.hpp" +#include "CNN.hpp" +#include "Concat.hpp" +#include "WordEmbeddings.hpp" + +class HistoryMineModuleImpl : public Submodule +{ + private : + + WordEmbeddings wordEmbeddings{nullptr}; + std::shared_ptr<MyModule> myModule{nullptr}; + int maxNbElements; + int inSize; + + public : + + HistoryMineModuleImpl(std::string name, const std::string & definition); + torch::Tensor forward(torch::Tensor input); + std::size_t getOutputSize() override; + std::size_t getInputSize() override; + void addToContext(torch::Tensor & context, const Config & config) override; + void registerEmbeddings() override; +}; +TORCH_MODULE(HistoryMineModule); + +#endif + diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index 31685e2..e2d4643 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -13,6 +13,7 @@ #include "UppercaseRateModule.hpp" #include "NumericColumnModule.hpp" #include "HistoryModule.hpp" +#include "HistoryMineModule.hpp" #include "DistanceModule.hpp" #include "MLP.hpp" diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp new file mode 100644 index 0000000..cf8338d --- /dev/null +++ b/torch_modules/src/HistoryMineModule.cpp @@ -0,0 +1,74 @@ +#include "HistoryMineModule.hpp" + +HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string & definition) +{ + setName(name); + std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)"); + if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm) + { + try + { + maxNbElements = std::stoi(sm.str(1)); + + auto subModuleType = sm.str(2); + auto subModuleArguments = util::split(sm.str(3), ' '); + + auto options = MyModule::ModuleOptions(true) + .bidirectional(std::stoi(subModuleArguments[0])) + .num_layers(std::stoi(subModuleArguments[1])) + .dropout(std::stof(subModuleArguments[2])) + .complete(std::stoi(subModuleArguments[3])); + + inSize = std::stoi(sm.str(4)); + int outSize = std::stoi(sm.str(5)); + + if (subModuleType == "LSTM") + myModule = register_module("myModule", LSTM(inSize, outSize, options)); + else if (subModuleType == "GRU") + myModule = register_module("myModule", GRU(inSize, outSize, options)); + else if (subModuleType == "CNN") + myModule = register_module("myModule", CNN(inSize, outSize, options)); + else if (subModuleType == "Concat") + myModule = register_module("myModule", Concat(inSize)); + else + util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); + + } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} + })) + util::myThrow(fmt::format("invalid definition '{}'", definition)); +} + +torch::Tensor HistoryMineModuleImpl::forward(torch::Tensor input) +{ + return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1}); +} + +std::size_t HistoryMineModuleImpl::getOutputSize() +{ + return myModule->getOutputSize(maxNbElements); +} + +std::size_t HistoryMineModuleImpl::getInputSize() +{ + return maxNbElements; +} + +void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & config) +{ + auto & dict = getDict(); + + std::string prefix = "HISTORYMINE"; + + for (int i = 0; i < maxNbElements; i++) + if (config.hasHistoryState(i)) + context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistoryState(i), prefix); + else + context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); +} + +void HistoryMineModuleImpl::registerEmbeddings() +{ + if (!wordEmbeddings) + wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize)); +} + diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index 89b2de0..e288df2 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -35,6 +35,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); else if (splited.first == "History") modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second))); + else if (splited.first == "HistoryMine") + modules.emplace_back(register_module(name, HistoryMineModule(nameH, splited.second))); else if (splited.first == "NumericColumn") modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second))); else if (splited.first == "UppercaseRate") -- GitLab