diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d7e290cc066fa2a1a842b02f702c5c717390e062
--- /dev/null
+++ b/torch_modules/include/ContextualModule.hpp
@@ -0,0 +1,39 @@
+#ifndef CONTEXTUALMODULE__H
+#define CONTEXTUALMODULE__H
+
+#include <torch/torch.h>
+#include <optional>
+#include "Submodule.hpp"
+#include "MyModule.hpp"
+#include "GRU.hpp"
+#include "LSTM.hpp"
+#include "Concat.hpp"
+
+class ContextualModuleImpl : public Submodule
+{
+  private :
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  std::shared_ptr<MyModule> myModule{nullptr};
+  std::vector<std::string> columns;
+  std::vector<std::function<std::string(const std::string &)>> functions;
+  std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
+  std::pair<int,int> window;
+  int inSize;
+  int outSize;
+  std::filesystem::path path;
+  std::filesystem::path w2vFile;
+
+  public :
+
+  ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
+  torch::Tensor forward(torch::Tensor input);
+  std::size_t getOutputSize() override;
+  std::size_t getInputSize() override;
+  void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
+  void registerEmbeddings() override;
+};
+TORCH_MODULE(ContextualModule);
+
+#endif
+
diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp
index 3fa417ba2ffbb566f79316cdd7a9b51f5a14c6d7..9b7efaec8e94d55f745aaf4d16ab7c5f5877c811 100644
--- a/torch_modules/include/ModularNetwork.hpp
+++ b/torch_modules/include/ModularNetwork.hpp
@@ -3,6 +3,7 @@
 
 #include "NeuralNetwork.hpp"
 #include "ContextModule.hpp"
+#include "ContextualModule.hpp"
 #include "RawInputModule.hpp"
 #include "SplitTransModule.hpp"
 #include "AppliableTransModule.hpp"
diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp
index 2e13833591fbac8fe99125e122d093d6b1611971..a3129c8eef06a2cf41056571c8b50eb95180f59d 100644
--- a/torch_modules/src/ContextModule.cpp
+++ b/torch_modules/src/ContextModule.cpp
@@ -148,7 +148,7 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
   context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*context.size(2)});
 
-  return myModule->forward(context);
+  return myModule->forward(context).reshape({input.size(0), -1});
 }
 
 void ContextModuleImpl::registerEmbeddings()
diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d5f60788f769fa801b667b1f21815a47b53a4e34
--- /dev/null
+++ b/torch_modules/src/ContextualModule.cpp
@@ -0,0 +1,194 @@
+#include "ContextualModule.hpp"
+
+ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
+{
+  setName(name);
+
+  std::regex regex("(?:(?:\\s|\\t)*)Window\\{(.*)\\}(?:(?:\\s|\\t)*)Columns\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)w2v\\{(.*)\\}(?:(?:\\s|\\t)*)Targets\\{(.*)\\}(?:(?:\\s|\\t)*)");
+  if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
+        {
+          try
+          {
+            auto splited = util::split(sm.str(1), ' ');
+            if (splited.size() != 2)
+              util::myThrow("bad Window, expected 2 indexes");
+            window = std::make_pair(std::stoi(splited[0]), std::stoi(splited[1]));
+
+            auto funcColumns = util::split(sm.str(2), ' ');
+            columns.clear();
+            for (auto & funcCol : funcColumns)
+            {
+              functions.emplace_back(getFunction(funcCol));
+              columns.emplace_back(util::split(funcCol, ':').back());
+            }
+
+            auto subModuleType = sm.str(3);
+            auto subModuleArguments = util::split(sm.str(4), ' ');
+
+            auto options = MyModule::ModuleOptions(true)
+              .bidirectional(std::stoi(subModuleArguments[0]))
+              .num_layers(std::stoi(subModuleArguments[1]))
+              .dropout(std::stof(subModuleArguments[2]))
+              .complete(std::stoi(subModuleArguments[3]));
+
+            for (auto & target : util::split(sm.str(8), ' '))
+            {
+              auto splited = util::split(target, '.');
+              if (splited.size() != 2 and splited.size() != 3)
+                util::myThrow(fmt::format("invalid target '{}', expected 'object.index(.childIndex)'", target));
+              targets.emplace_back(std::make_tuple(Config::str2object(splited[0]), std::stoi(splited[1]), splited.size() == 3 ? std::optional<int>(std::stoi(splited[2])) : std::optional<int>()));
+            }
+
+            inSize = std::stoi(sm.str(5));
+            outSize = std::stoi(sm.str(6));
+            if (outSize % 2)
+              util::myThrow("odd out size is not supported");
+
+            if (subModuleType == "LSTM")
+              myModule = register_module("myModule", LSTM(columns.size()*inSize, outSize, options));
+            else if (subModuleType == "GRU")
+              myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
+            else if (subModuleType == "Concat")
+              myModule = register_module("myModule", Concat(inSize));
+            else
+              util::myThrow(fmt::format("unknown submodule type '{}'", subModuleType));
+
+            w2vFile = sm.str(7);
+
+            if (!w2vFile.empty())
+            {
+              getDict().loadWord2Vec(this->path / w2vFile);
+              getDict().setState(Dict::State::Closed);
+              dictSetPretrained(true);
+            }
+
+          } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
+        }))
+    util::myThrow(fmt::format("invalid definition '{}'", definition));
+}
+
+std::size_t ContextualModuleImpl::getOutputSize()
+{
+  return targets.size()*outSize;
+}
+
+std::size_t ContextualModuleImpl::getInputSize()
+{
+  return columns.size()*(2+window.second-window.first)+targets.size();
+}
+
+void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
+{
+  auto & dict = getDict();
+  std::vector<long> contextIndexes;
+  std::vector<long> targetIndexes;
+  std::map<long,long> configIndex2ContextIndex;
+
+  contextIndexes.emplace_back(-2);
+
+  for (long i = window.first; i <= window.second; i++)
+  {
+    if (config.hasRelativeWordIndex(Config::Object::Buffer, i))
+    {
+      contextIndexes.emplace_back(config.getRelativeWordIndex(Config::Object::Buffer, i));
+      configIndex2ContextIndex[contextIndexes.back()] = contextIndexes.size()-1;
+    }
+    else
+      contextIndexes.emplace_back(-1);
+  }
+
+  for (auto & target : targets)
+    if (config.hasRelativeWordIndex(std::get<0>(target), std::get<1>(target)))
+    {
+      int baseIndex = config.getRelativeWordIndex(std::get<0>(target), std::get<1>(target));
+      if (!std::get<2>(target))
+        targetIndexes.emplace_back(baseIndex);
+      else
+      {
+        int childIndex = *std::get<2>(target);
+        auto childs = util::split(config.getAsFeature(Config::childsColName, baseIndex).get(), '|');
+        int candidate = -1;
+
+        if (childIndex >= 0 and childIndex < (int)childs.size())
+          candidate = std::stoi(childs[childIndex]);
+        else if (childIndex < 0 and ((int)childs.size())+childIndex >= 0)
+          candidate = std::stoi(childs[childs.size()+childIndex]);
+
+        targetIndexes.emplace_back(candidate);
+      }
+    }
+    else
+      targetIndexes.emplace_back(-1);
+
+  for (auto index : contextIndexes)
+    for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++)
+    {
+      auto & col = columns[colIndex];
+      if (index == -1)
+      {
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, Dict::nullValueStr)));
+      }
+      else if (index == -2)
+      {
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}({})", col, "_NONE_")));
+      }
+      else
+      {
+        int dictIndex;
+        if (col == Config::idColName)
+        {
+          std::string value;
+          if (config.isCommentPredicted(index))
+            value = "ID(comment)";
+          else if (config.isMultiwordPredicted(index))
+            value = "ID(multiword)";
+          else if (config.isTokenPredicted(index))
+            value = "ID(token)";
+          dictIndex = dict.getIndexOrInsert(value);
+        }
+        else if (col == Config::EOSColName)
+        {
+          dictIndex = dict.getIndexOrInsert(fmt::format("EOS({})", config.getAsFeature(col, index)));
+        }
+        else
+          dictIndex = dict.getIndexOrInsert(functions[colIndex](config.getAsFeature(col, index)));
+
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+      }
+    }
+
+  for (auto index : targetIndexes)
+  {
+    if (configIndex2ContextIndex.count(index))
+      for (auto & contextElement : context)
+        contextElement.push_back(configIndex2ContextIndex.at(index)); // position of the target in the RNN output
+    else
+      for (auto & contextElement : context)
+        contextElement.push_back(0); // position 0 is the _NONE_ sentinel
+  }
+}
+
+torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
+{
+  auto context = wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()-targets.size())).view({input.size(0), (2+window.second-window.first), -1});
+  auto focusedIndexes = input.narrow(1, firstInputIndex+getInputSize()-targets.size(), targets.size());
+
+  auto out = myModule->forward(context);
+
+  std::vector<torch::Tensor> batchElems;
+
+  for (long i = 0; i < input.size(0); i++)
+    batchElems.emplace_back(torch::index_select(out[i], 0, focusedIndexes[i]).view({-1}));
+
+  return torch::stack(batchElems);
+}
+
+void ContextualModuleImpl::registerEmbeddings()
+{
+  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
"" : path / w2vFile); +} + diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 1b8034b51e2b19ec65294e5ae8f2bc7d50cfa016..40098bc37e96c2008cf4f331b761debfaaf022f9 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -50,7 +50,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def torch::Tensor DistanceModuleImpl::forward(torch::Tensor input) { - return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))); + return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1}); } std::size_t DistanceModuleImpl::getOutputSize() diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 91d22b03d3f59022487be82158a59c08d4b6fbdb..62c1aef1e364b2616ef48d6559468d944546115e 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -48,7 +48,7 @@ torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input) { std::vector<torch::Tensor> outputs; for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++) - outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements)))); + outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))).reshape({input.size(0), -1})); return torch::cat(outputs, 1); } diff --git a/torch_modules/src/GRU.cpp b/torch_modules/src/GRU.cpp index fa6de5ccf0d594166f2544465dc5903a204692d4..a0ab6f22cd88f37b8d6c37d7fac730c850fe7cf3 100644 --- a/torch_modules/src/GRU.cpp +++ b/torch_modules/src/GRU.cpp @@ -2,7 +2,7 @@ GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options)) { - auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize) + auto gruOptions = torch::nn::GRUOptions(inputSize, std::get<1>(options) ? 
diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp
index 1b8034b51e2b19ec65294e5ae8f2bc7d50cfa016..40098bc37e96c2008cf4f331b761debfaaf022f9 100644
--- a/torch_modules/src/DistanceModule.cpp
+++ b/torch_modules/src/DistanceModule.cpp
@@ -50,7 +50,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
 
 torch::Tensor DistanceModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }
 
 std::size_t DistanceModuleImpl::getOutputSize()
diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp
index 91d22b03d3f59022487be82158a59c08d4b6fbdb..62c1aef1e364b2616ef48d6559468d944546115e 100644
--- a/torch_modules/src/FocusedColumnModule.cpp
+++ b/torch_modules/src/FocusedColumnModule.cpp
@@ -48,7 +48,7 @@ torch::Tensor FocusedColumnModuleImpl::forward(torch::Tensor input)
 {
   std::vector<torch::Tensor> outputs;
   for (unsigned int i = 0; i < focusedBuffer.size()+focusedStack.size(); i++)
-    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))));
+    outputs.emplace_back(myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex+i*maxNbElements, maxNbElements))).reshape({input.size(0), -1}));
 
   return torch::cat(outputs, 1);
 }
diff --git a/torch_modules/src/GRU.cpp b/torch_modules/src/GRU.cpp
index fa6de5ccf0d594166f2544465dc5903a204692d4..a0ab6f22cd88f37b8d6c37d7fac730c850fe7cf3 100644
--- a/torch_modules/src/GRU.cpp
+++ b/torch_modules/src/GRU.cpp
@@ -2,7 +2,7 @@
 
 GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
-  auto gruOptions = torch::nn::GRUOptions(inputSize, outputSize)
+  auto gruOptions = torch::nn::GRUOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
     .batch_first(std::get<0>(options))
     .bidirectional(std::get<1>(options))
     .num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ GRUImpl::GRUImpl(int inputSize, int outputSize, ModuleOptions options) : outputA
 
 torch::Tensor GRUImpl::forward(torch::Tensor input)
 {
-  auto gruOut = std::get<0>(gru(input));
-
-  if (outputAll)
-    return gruOut.reshape({gruOut.size(0), -1});
-
-  if (gru->options.bidirectional())
-    return torch::cat({gruOut.narrow(1,0,1).squeeze(1), gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1)}, 1);
-
-  return gruOut.narrow(1,gruOut.size(1)-1,1).squeeze(1);
+  return std::get<0>(gru(input));
 }
 
 int GRUImpl::getOutputSize(int sequenceLength)
diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp
index c326f52ef7645dea75feed3f08b1756be50d8df2..3e09b0aa7760f687b2d70efba5e696f93a45304b 100644
--- a/torch_modules/src/HistoryModule.cpp
+++ b/torch_modules/src/HistoryModule.cpp
@@ -38,7 +38,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
 
 torch::Tensor HistoryModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements)));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1});
 }
 
 std::size_t HistoryModuleImpl::getOutputSize()
diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp
index 2844b17a256bac5de90017343fc4c7b2ad466e89..d84af461c789d1f295640d26bfddd18208f1c89f 100644
--- a/torch_modules/src/LSTM.cpp
+++ b/torch_modules/src/LSTM.cpp
@@ -2,7 +2,7 @@
 
 LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outputAll(std::get<4>(options))
 {
-  auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize)
+  auto lstmOptions = torch::nn::LSTMOptions(inputSize, std::get<1>(options) ? outputSize/2 : outputSize)
     .batch_first(std::get<0>(options))
     .bidirectional(std::get<1>(options))
     .num_layers(std::get<2>(options))
@@ -13,15 +13,7 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu
 
 torch::Tensor LSTMImpl::forward(torch::Tensor input)
 {
-  auto lstmOut = std::get<0>(lstm(input));
-
-  if (outputAll)
-    return lstmOut.reshape({lstmOut.size(0), -1});
-
-  if (lstm->options.bidirectional())
-    return torch::cat({lstmOut.narrow(1,0,1).squeeze(1), lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1)}, 1);
-
-  return lstmOut.narrow(1,lstmOut.size(1)-1,1).squeeze(1);
+  return std::get<0>(lstm(input));
 }
 
 int LSTMImpl::getOutputSize(int sequenceLength)
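The GRU and LSTM changes above share one sizing convention: when the submodule is bidirectional its hidden size is outSize/2, so the concatenation of the two directions comes out at exactly outSize, and forward now returns the full sequence output, leaving any flattening or gathering to the calling module. The standalone snippet below is a minimal sketch of the resulting shapes (the literal sizes are arbitrary; GRU is shown, LSTM behaves identically):

  #include <torch/torch.h>
  #include <cassert>

  int main()
  {
    const long batch = 2, seq = 5, inSize = 8, outSize = 6;

    // Bidirectional GRU with hidden size outSize/2, as in GRUImpl above.
    auto gru = torch::nn::GRU(torch::nn::GRUOptions(inSize, outSize/2)
      .batch_first(true)
      .bidirectional(true));

    // The full sequence output has shape {batch, seq, outSize}.
    auto out = std::get<0>(gru(torch::randn({batch, seq, inSize})));
    assert(out.size(0) == batch and out.size(1) == seq and out.size(2) == outSize);

    // Gathering as in ContextualModuleImpl::forward: select the outputs at
    // positions 0 and 3 of the first batch element, then flatten.
    auto focused = torch::index_select(out[0], 0, torch::tensor({0L, 3L})).view({-1});
    assert(focused.size(0) == 2*outSize);

    return 0;
  }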
diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp
index 685060ff36557645c50c75bf1f049e71ceb38f4f..7dcf1c5d4fe2231b479131adf9d78ce5bd1fccc4 100644
--- a/torch_modules/src/ModularNetwork.cpp
+++ b/torch_modules/src/ModularNetwork.cpp
@@ -29,6 +29,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
     std::string nameH = fmt::format("{}_{}", getName(), name);
     if (splited.first == "Context")
       modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
+    else if (splited.first == "Contextual")
+      modules.emplace_back(register_module(name, ContextualModule(nameH, splited.second, path)));
     else if (splited.first == "StateName")
       modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
     else if (splited.first == "History")
diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp
index 182502324c727217ed498709b3c0ee831bb436e6..b535ded6b006b9249ddc1574c3aae816bb36a050 100644
--- a/torch_modules/src/NumericColumnModule.cpp
+++ b/torch_modules/src/NumericColumnModule.cpp
@@ -45,7 +45,7 @@ torch::Tensor NumericColumnModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-  return myModule->forward(values);
+  return myModule->forward(values).reshape({input.size(0), -1});
 }
 
 std::size_t NumericColumnModuleImpl::getOutputSize()
diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp
index c99c4ae24f6d9ea828fe6e7de4dc4a3637753f15..8f43a2fd310853793e93caaebc348ab1ff18b0be 100644
--- a/torch_modules/src/RawInputModule.cpp
+++ b/torch_modules/src/RawInputModule.cpp
@@ -39,7 +39,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
 
 torch::Tensor RawInputModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }
 
 std::size_t RawInputModuleImpl::getOutputSize()
diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp
index 822969f321d15dd99408a8472b21c7b289ad981a..d4f6d84067329448d1c5b9b5f6ccba01157b279f 100644
--- a/torch_modules/src/SplitTransModule.cpp
+++ b/torch_modules/src/SplitTransModule.cpp
@@ -38,7 +38,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
 
 torch::Tensor SplitTransModuleImpl::forward(torch::Tensor input)
 {
-  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize())));
+  return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, getInputSize()))).reshape({input.size(0), -1});
 }
 
 std::size_t SplitTransModuleImpl::getOutputSize()
diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp
index 818db8b01ae2575df8c0f4b7889214a8b2fdcc5a..0452eb8db781b8e83a1e62069b88c790b1214678 100644
--- a/torch_modules/src/UppercaseRateModule.cpp
+++ b/torch_modules/src/UppercaseRateModule.cpp
@@ -43,7 +43,7 @@ torch::Tensor UppercaseRateModuleImpl::forward(torch::Tensor input)
 {
   auto context = input.narrow(1, firstInputIndex, getInputSize());
   auto values = torch::from_blob(context.data_ptr(), context.sizes(), context.strides(), torch::TensorOptions(torch::kDouble).requires_grad(false).device(NeuralNetworkImpl::device)).to(torch::kFloat).unsqueeze(-1).clone();
-  return myModule->forward(values);
+  return myModule->forward(values).reshape({input.size(0), -1});
 }
 
 std::size_t UppercaseRateModuleImpl::getOutputSize()
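Sizing note, derived from the accessors in ContextualModule.cpp above: getOutputSize() is targets.size()*outSize, so the hypothetical definition shown after that file (6 targets, Out{128}) would contribute 6*128 = 768 features to the downstream layers, while getInputSize() reserves columns.size()*(2+window.second-window.first) dictionary indexes for the embedded window (the +2 covers the inclusive window bounds plus the leading _NONE_ sentinel position) plus one extra index per target.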