diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 9aad4f048b3e065e554732bdf1b56326bbff78c3..c9327339675b254fa33cfe9d6773598fdcfdda79 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -45,8 +45,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto appliableTransitions = machine.getTransitionSet(elements[index].config.getState()).getAppliableTransitions(elements[index].config); elements[index].config.setAppliableTransitions(appliableTransitions); - auto context = classifier.getNN()->extractContext(elements[index].config).back(); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); + auto neuralInput = classifier.getNN()->extractContext(elements[index].config); auto prediction = classifier.isRegression() ? classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, elements[index].config.getState()).squeeze(0), 0); float entropy = classifier.isRegression() ? 
0.0 : NeuralNetworkImpl::entropy(prediction); diff --git a/torch_modules/include/AppliableTransModule.hpp b/torch_modules/include/AppliableTransModule.hpp index 5e6f9e461109eac691920e9763106681f1461f38..98f5fe13a0b5e5ca9c46caa2318d4ab054509b69 100644 --- a/torch_modules/include/AppliableTransModule.hpp +++ b/torch_modules/include/AppliableTransModule.hpp @@ -19,7 +19,7 @@ class AppliableTransModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(AppliableTransModule); diff --git a/torch_modules/include/ContextModule.hpp b/torch_modules/include/ContextModule.hpp index 585188793b32779c8da3f92bfe6603b71ccbb6ae..c2e0668135942d7dbdfaf0fff9e6f5d72fc66880 100644 --- a/torch_modules/include/ContextModule.hpp +++ b/torch_modules/include/ContextModule.hpp @@ -30,7 +30,7 @@ class ContextModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(ContextModule); diff --git a/torch_modules/include/ContextualModule.hpp b/torch_modules/include/ContextualModule.hpp index e7fb2a90e62fe2840de0e56a3050960c137d67d3..8483b1a1c6199660fe88c6054888144a6bf1f6e0 100644 --- a/torch_modules/include/ContextualModule.hpp +++ b/torch_modules/include/ContextualModule.hpp @@ -31,7 +31,7 @@ class ContextualModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void 
addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(ContextualModule); diff --git a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp index 6da8943b132306bc57ef27780269be9263b1037f..3621e6e5df8165963788872100df71c4adaa7aad 100644 --- a/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp +++ b/torch_modules/include/DepthLayerTreeEmbeddingModule.hpp @@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(DepthLayerTreeEmbeddingModule); diff --git a/torch_modules/include/DistanceModule.hpp b/torch_modules/include/DistanceModule.hpp index 3702ad58e6f1cf084c8bd0bd823e62bb4112c774..bafa0b8e3d852ab22e6f787ec005b14c0ff81f5f 100644 --- a/torch_modules/include/DistanceModule.hpp +++ b/torch_modules/include/DistanceModule.hpp @@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(DistanceModule); diff --git a/torch_modules/include/FocusedColumnModule.hpp b/torch_modules/include/FocusedColumnModule.hpp index af370c64509546a8adc3c3558f3775f14dbbfd03..a7df33187f2486e9aac28d60ca3dbd5a77492ecc 100644 --- 
a/torch_modules/include/FocusedColumnModule.hpp +++ b/torch_modules/include/FocusedColumnModule.hpp @@ -29,7 +29,7 @@ class FocusedColumnModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(FocusedColumnModule); diff --git a/torch_modules/include/HistoryModule.hpp b/torch_modules/include/HistoryModule.hpp index 54418a6398ba15f24821def4059b719e24737746..b4a725b8a4cf2ec79b1891cb39cfa21530f2278f 100644 --- a/torch_modules/include/HistoryModule.hpp +++ b/torch_modules/include/HistoryModule.hpp @@ -25,7 +25,7 @@ class HistoryModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(HistoryModule); diff --git a/torch_modules/include/ModularNetwork.hpp b/torch_modules/include/ModularNetwork.hpp index ed73c301bd90134f50b98a4778d5a6539b54f9aa..31685e29ea45326c94acbfa62efe9dcc280a747b 100644 --- a/torch_modules/include/ModularNetwork.hpp +++ b/torch_modules/include/ModularNetwork.hpp @@ -25,12 +25,13 @@ class ModularNetworkImpl : public NeuralNetworkImpl MLP mlp{nullptr}; std::vector<std::shared_ptr<Submodule>> modules; std::map<std::string,torch::nn::Linear> outputLayersPerState; + std::size_t totalInputSize{0}; public : ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); torch::Tensor forward(torch::Tensor input, const std::string & state) 
override; - std::vector<std::vector<long>> extractContext(Config & config) override; + torch::Tensor extractContext(Config & config) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 8215ad2fff9438ab0b6e133f1591b9cddc7c369e..ffbcdea03d406e38aaef06830b54e2809f27b0a4 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -15,7 +15,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder public : virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; - virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; + virtual torch::Tensor extractContext(Config & config) = 0; virtual void registerEmbeddings() = 0; virtual void saveDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0; diff --git a/torch_modules/include/NumericColumnModule.hpp b/torch_modules/include/NumericColumnModule.hpp index e3d46aece017226572fbc469e53f85dfbaddc518..3ee9cb217f31b1eeaf9e0d44554e19ce078d1923 100644 --- a/torch_modules/include/NumericColumnModule.hpp +++ b/torch_modules/include/NumericColumnModule.hpp @@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(NumericColumnModule); diff --git a/torch_modules/include/RandomNetwork.hpp b/torch_modules/include/RandomNetwork.hpp index 3c559e9c393146a86bb9ffc1c6f3dd42f0a89ac2..33d99a1455b521f88e972d002103baa4053ee5da 100644 --- 
a/torch_modules/include/RandomNetwork.hpp +++ b/torch_modules/include/RandomNetwork.hpp @@ -13,7 +13,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); torch::Tensor forward(torch::Tensor input, const std::string & state) override; - std::vector<std::vector<long>> extractContext(Config &) override; + torch::Tensor extractContext(Config &) override; void registerEmbeddings() override; void saveDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override; diff --git a/torch_modules/include/RawInputModule.hpp b/torch_modules/include/RawInputModule.hpp index 26237e2fefc63093b693ebf1860f99492e7beeed..0ca658b978bc2d0bacb6a6514e64751f5305d5a5 100644 --- a/torch_modules/include/RawInputModule.hpp +++ b/torch_modules/include/RawInputModule.hpp @@ -25,7 +25,7 @@ class RawInputModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(RawInputModule); diff --git a/torch_modules/include/SplitTransModule.hpp b/torch_modules/include/SplitTransModule.hpp index 1ef1796c490f2192b9d12fb1f2a77bdfe8ec786c..b88491e6d6f0f019b7c3c5e8c0f8482308fa5b72 100644 --- a/torch_modules/include/SplitTransModule.hpp +++ b/torch_modules/include/SplitTransModule.hpp @@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; 
TORCH_MODULE(SplitTransModule); diff --git a/torch_modules/include/StateNameModule.hpp b/torch_modules/include/StateNameModule.hpp index 3abfe826021001fe3bc1d04016bea3ed353069b2..ace1cbc63d66e164a155a9b9b611a7d74c82233e 100644 --- a/torch_modules/include/StateNameModule.hpp +++ b/torch_modules/include/StateNameModule.hpp @@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(StateNameModule); diff --git a/torch_modules/include/Submodule.hpp b/torch_modules/include/Submodule.hpp index 1dbbdc7e46844a910a5d0884c46e8f6e62f192ae..f4722bf195537ea4ae0d1b93f9bae38cab268c6e 100644 --- a/torch_modules/include/Submodule.hpp +++ b/torch_modules/include/Submodule.hpp @@ -24,7 +24,7 @@ class Submodule : public torch::nn::Module, public DictHolder void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix); virtual std::size_t getOutputSize() = 0; virtual std::size_t getInputSize() = 0; - virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; + virtual void addToContext(torch::Tensor & context, const Config & config) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0; virtual void registerEmbeddings() = 0; std::function<std::string(const std::string &)> getFunction(const std::string functionNames); diff --git a/torch_modules/include/UppercaseRateModule.hpp b/torch_modules/include/UppercaseRateModule.hpp index dcfb89c3d08f42dfd4f595d79f48dcd639594559..94956613a7974bfbcb91e7ea2768f6ba195a3ecd 100644 --- a/torch_modules/include/UppercaseRateModule.hpp +++ b/torch_modules/include/UppercaseRateModule.hpp @@ -22,7 +22,7 @@ class 
UppercaseRateModuleImpl : public Submodule torch::Tensor forward(torch::Tensor input); std::size_t getOutputSize() override; std::size_t getInputSize() override; - void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; + void addToContext(torch::Tensor & context, const Config & config) override; void registerEmbeddings() override; }; TORCH_MODULE(UppercaseRateModule); diff --git a/torch_modules/src/AppliableTransModule.cpp b/torch_modules/src/AppliableTransModule.cpp index c50586f1e49ed7001d0ecc6643b2f6af36047226..7a5c830cedff73c185e930141611b396cfdb8b44 100644 --- a/torch_modules/src/AppliableTransModule.cpp +++ b/torch_modules/src/AppliableTransModule.cpp @@ -20,15 +20,12 @@ std::size_t AppliableTransModuleImpl::getInputSize() return nbTrans; } -void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void AppliableTransModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & appliableTrans = config.getAppliableTransitions(); - for (auto & contextElement : context) - for (int i = 0; i < nbTrans; i++) - if (i < (int)appliableTrans.size()) - contextElement.emplace_back(appliableTrans[i]); - else - contextElement.emplace_back(0); + for (int i = 0; i < nbTrans; i++) + if (i < (int)appliableTrans.size()) + context[firstInputIndex+i] = appliableTrans[i]; } void AppliableTransModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index f457f314b3595bf0c51d438893c3743609605a98..19c7972123aab7fa47822aaf5e290cb69fbdb54a 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -83,7 +83,7 @@ std::size_t ContextModuleImpl::getInputSize() return columns.size()*(targets.size()); } -void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void ContextModuleImpl::addToContext(torch::Tensor & context, const Config & 
config) { auto & dict = getDict(); std::vector<long> contextIndexes; @@ -125,24 +125,22 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c contextIndexes.emplace_back(-3); } + int insertIndex = 0; for (auto index : contextIndexes) for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++) { auto & col = columns[colIndex]; if (index == -1) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col); } else if (index == -2) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col); } else if (index == -3) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col); } else { @@ -162,9 +160,9 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c dictIndex = dict.getIndexOrInsert(featureValue, col); } - for (auto & contextElement : context) - contextElement.push_back(dictIndex); + context[firstInputIndex+insertIndex] = dictIndex; } + insertIndex++; } } diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index f5f8562e26b811af74966ec8343eeaacb6774e2f..1c569831a4b15f71a8314326e5c7eccd96c337d6 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -87,7 +87,7 @@ std::size_t ContextualModuleImpl::getInputSize() return columns.size()*(4+window.second-window.first)+targets.size(); } -void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void ContextualModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & 
dict = getDict(); std::vector<long> contextIndexes; @@ -132,24 +132,23 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context else targetIndexes.emplace_back(-1); + int insertIndex = 0; + for (auto index : contextIndexes) for (unsigned int colIndex = 0; colIndex < columns.size(); colIndex++) { auto & col = columns[colIndex]; if (index == -1) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col); } else if (index == -2) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::noChildValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::noChildValueStr, col); } else if (index == -3) { - for (auto & contextElement : context) - contextElement.push_back(dict.getIndexOrInsert(Dict::oobValueStr, col)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::oobValueStr, col); } else { @@ -169,33 +168,32 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context dictIndex = dict.getIndexOrInsert(functions[colIndex](featureValue), col); } - for (auto & contextElement : context) - contextElement.push_back(dictIndex); + context[firstInputIndex+insertIndex] = dictIndex; } + + insertIndex++; } for (auto index : targetIndexes) { if (configIndex2ContextIndex.count(index)) { - for (auto & contextElement : context) - contextElement.push_back(configIndex2ContextIndex.at(index)); + context[firstInputIndex+insertIndex] = configIndex2ContextIndex.at(index); } else { - for (auto & contextElement : context) - { - // -1 == doesn't exist (s.0 when no stack) - if (index == -1) - contextElement.push_back(0); - // -2 == nochild - else if (index == -2) - contextElement.push_back(1); - // other == out of context bounds - else - contextElement.push_back(2); - } + // -1 == doesn't exist (s.0 when no stack) + if 
(index == -1) + context[firstInputIndex+insertIndex] = 0; + // -2 == nochild + else if (index == -2) + context[firstInputIndex+insertIndex] = 1; + // other == out of context bounds + else + context[firstInputIndex+insertIndex] = 2; } + + insertIndex++; } } diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index 06c0b5fe222ebbddfa644feed25ff1a72249ce45..ac906908408baf2c8375618294a9892929bc3062 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -84,7 +84,7 @@ std::size_t DepthLayerTreeEmbeddingModuleImpl::getInputSize() return inputSize; } -void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void DepthLayerTreeEmbeddingModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); std::vector<long> focusedIndexes; @@ -98,30 +98,34 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon else focusedIndexes.emplace_back(-1); - for (auto & contextElement : context) - for (auto index : focusedIndexes) + int insertIndex = 0; + for (auto index : focusedIndexes) + { + std::vector<std::string> childs{std::to_string(index)}; + + for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) { - std::vector<std::string> childs{std::to_string(index)}; + std::vector<std::string> newChilds; + for (auto & child : childs) + if (config.has(Config::childsColName, std::stoi(child), 0)) + { + auto val = util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|'); + newChilds.insert(newChilds.end(), val.begin(), val.end()); + } + childs = newChilds; - for (unsigned int depth = 0; depth < maxElemPerDepth.size(); depth++) - { - std::vector<std::string> newChilds; - for (auto & child : childs) - if (config.has(Config::childsColName, std::stoi(child), 0)) - { - auto val = 
util::split(config.getAsFeature(Config::childsColName, std::stoi(child)).get(), '|'); - newChilds.insert(newChilds.end(), val.begin(), val.end()); - } - childs = newChilds; - - for (int i = 0; i < maxElemPerDepth[depth]; i++) - for (auto & col : columns) - if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) - contextElement.emplace_back(dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col)); - else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, col)); - } + for (int i = 0; i < maxElemPerDepth[depth]; i++) + for (auto & col : columns) + { + if (i < (int)newChilds.size() and config.has(col, std::stoi(newChilds[i]), 0)) + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(config.getAsFeature(col,std::stoi(newChilds[i])), col); + else + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, col); + + insertIndex++; + } } + } } void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index 2e71e25cdfd6b174af1115ef636e28cc581365e3..fddf2e00867c73b9a8d7559c03bafcc6a059177f 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -63,7 +63,7 @@ std::size_t DistanceModuleImpl::getInputSize() return (fromBuffer.size()+fromStack.size()) * (toBuffer.size()+toStack.size()); } -void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void DistanceModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); std::vector<long> fromIndexes, toIndexes; @@ -88,25 +88,26 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, std::string prefix = "DISTANCE"; - for (auto & contextElement : context) - { - for (auto from : fromIndexes) - for (auto to : toIndexes) + int insertIndex = 0; + for (auto from : fromIndexes) + for (auto to : 
toIndexes) + { + if (from == -1 or to == -1) + { + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); + } + else { - if (from == -1 or to == -1) - { - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); - continue; - } - long dist = std::abs(config.getRelativeDistance(from, to)); if (dist <= threshold) - contextElement.emplace_back(dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), "")); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}({})", prefix, dist), ""); else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::unknownValueStr, prefix)); + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::unknownValueStr, prefix); } - } + + insertIndex++; + } } void DistanceModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 115f918ad3c845a52f1366b277d42b6b35e4b616..3fc25f0ad53ef521f05da2bdbc5aabf0216ef096 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -82,7 +82,7 @@ std::size_t FocusedColumnModuleImpl::getInputSize() return (focusedBuffer.size()+focusedStack.size()) * maxNbElements; } -void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void FocusedColumnModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); std::vector<long> focusedIndexes; @@ -96,63 +96,67 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont else focusedIndexes.emplace_back(-1); - for (auto & contextElement : context) + int insertIndex = 0; + for (auto index : focusedIndexes) { - for (auto index : focusedIndexes) + if (index == -1) { - if (index == -1) + for (int i = 0; i < maxNbElements; i++) { - for (int i = 0; i < maxNbElements; i++) - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, 
column)); - continue; + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, column); + insertIndex++; } + continue; + } - std::vector<std::string> elements; - if (column == "FORM") - { - auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get())); - - //TODO don't use nullValueStr here - for (int i = 0; i < maxNbElements; i++) - if (i < (int)asUtf8.size()) - elements.emplace_back(fmt::format("{}", asUtf8[i])); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (column == "FEATS") - { - auto splited = util::split(func(config.getAsFeature(column, index).get()), '|'); + std::vector<std::string> elements; + if (column == "FORM") + { + auto asUtf8 = util::splitAsUtf8(func(config.getAsFeature(column, index).get())); + + //TODO don't use nullValueStr here + for (int i = 0; i < maxNbElements; i++) + if (i < (int)asUtf8.size()) + elements.emplace_back(fmt::format("{}", asUtf8[i])); + else + elements.emplace_back(Dict::nullValueStr); + } + else if (column == "FEATS") + { + auto splited = util::split(func(config.getAsFeature(column, index).get()), '|'); - for (int i = 0; i < maxNbElements; i++) - if (i < (int)splited.size()) - elements.emplace_back(splited[i]); - else - elements.emplace_back(Dict::nullValueStr); - } - else if (column == "ID") - { - if (config.isTokenPredicted(index)) - elements.emplace_back("TOKEN"); - else if (config.isMultiwordPredicted(index)) - elements.emplace_back("MULTIWORD"); - else if (config.isEmptyNodePredicted(index)) - elements.emplace_back("EMPTYNODE"); - } - else if (column == "EOS") - { - bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1; - elements.emplace_back(fmt::format("{}", isEOS)); - } - else - { - elements.emplace_back(func(config.getAsFeature(column, index))); - } + for (int i = 0; i < maxNbElements; i++) + if (i < (int)splited.size()) + elements.emplace_back(splited[i]); + else + elements.emplace_back(Dict::nullValueStr); + } + else 
if (column == "ID") + { + if (config.isTokenPredicted(index)) + elements.emplace_back("TOKEN"); + else if (config.isMultiwordPredicted(index)) + elements.emplace_back("MULTIWORD"); + else if (config.isEmptyNodePredicted(index)) + elements.emplace_back("EMPTYNODE"); + } + else if (column == "EOS") + { + bool isEOS = func(config.getAsFeature(Config::EOSColName, index)) == Config::EOSSymbol1; + elements.emplace_back(fmt::format("{}", isEOS)); + } + else + { + elements.emplace_back(func(config.getAsFeature(column, index))); + } - if ((int)elements.size() != maxNbElements) - util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements)); + if ((int)elements.size() != maxNbElements) + util::myThrow(fmt::format("elements.size ({}) != maxNbElements ({})", elements.size(), maxNbElements)); - for (auto & element : elements) - contextElement.emplace_back(dict.getIndexOrInsert(element, column)); + for (auto & element : elements) + { + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(element, column); + insertIndex++; } } } diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index 7d0912ce154af1025d409df3f7d3de4f40eae683..4a9033fe01e989301b12553245033c1a876b5ace 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -53,18 +53,17 @@ std::size_t HistoryModuleImpl::getInputSize() return maxNbElements; } -void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void HistoryModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); std::string prefix = "HISTORY"; - for (auto & contextElement : context) - for (int i = 0; i < maxNbElements; i++) - if (config.hasHistory(i)) - contextElement.emplace_back(dict.getIndexOrInsert(config.getHistory(i), prefix)); - else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); + for (int i = 0; i < 
maxNbElements; i++) + if (config.hasHistory(i)) + context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistory(i), prefix); + else + context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); } void HistoryModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/ModularNetwork.cpp b/torch_modules/src/ModularNetwork.cpp index c936f85b75ebfc4d6ed9686b091a183d53bc5adc..1c39f186db7bdc62e6e30c3771974dacefa8af63 100644 --- a/torch_modules/src/ModularNetwork.cpp +++ b/torch_modules/src/ModularNetwork.cpp @@ -69,6 +69,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st currentOutputSize += modules.back()->getOutputSize(); } + totalInputSize = currentInputSize; + if (mlpDef.empty()) util::myThrow("no MLP definition found"); if (inputDropout.is_empty()) @@ -95,9 +97,9 @@ torch::Tensor ModularNetworkImpl::forward(torch::Tensor input, const std::string return outputLayersPerState.at(state)(mlp(totalInput)); } -std::vector<std::vector<long>> ModularNetworkImpl::extractContext(Config & config) +torch::Tensor ModularNetworkImpl::extractContext(Config & config) { - std::vector<std::vector<long>> context(1); + torch::Tensor context = torch::zeros({totalInputSize}, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device)); for (auto & mod : modules) mod->addToContext(context, config); return context; diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index 15d2b19ef681539418dbd1c64dcffb27f7eabe54..49d3016bb00452efee0ffa7fc0d45d5a80a58bae 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -60,7 +60,7 @@ std::size_t NumericColumnModuleImpl::getInputSize() return focusedBuffer.size() + focusedStack.size(); } -void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void NumericColumnModuleImpl::addToContext(torch::Tensor & context, const Config & 
config) { std::vector<long> focusedIndexes; @@ -73,21 +73,21 @@ void NumericColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont else focusedIndexes.emplace_back(-1); - for (auto & contextElement : context) - for (auto index : focusedIndexes) + int insertIndex = 0; + for (auto index : focusedIndexes) + { + double res = 0.0; + if (index >= 0) { - double res = 0.0; - if (index >= 0) - { - auto value = config.getAsFeature(column, index).get(); - try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);} - catch (std::exception & e) - {util::myThrow(fmt::format("{} for '{}'", e.what(), value));} - } - - contextElement.emplace_back(0); - std::memcpy(&contextElement.back(), &res, sizeof res); + auto value = config.getAsFeature(column, index).get(); + try {res = (value == "_" or value == "NA") ? defaultValue : std::stof(value);} + catch (std::exception & e) + {util::myThrow(fmt::format("{} for '{}'", e.what(), value));} } + + //TODO : Check if this works + context[firstInputIndex+insertIndex] = res; + } } void NumericColumnModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/RandomNetwork.cpp b/torch_modules/src/RandomNetwork.cpp index 87a604636595062008ecbd5d442111b4101c8b39..b05d1aa00b26677500bc7f8a2acb59ea12a2cbcd 100644 --- a/torch_modules/src/RandomNetwork.cpp +++ b/torch_modules/src/RandomNetwork.cpp @@ -13,9 +13,10 @@ torch::Tensor RandomNetworkImpl::forward(torch::Tensor input, const std::string return torch::randn({input.size(0), (long)nbOutputsPerState[state]}, torch::TensorOptions().device(device).requires_grad(true)); } -std::vector<std::vector<long>> RandomNetworkImpl::extractContext(Config &) +torch::Tensor RandomNetworkImpl::extractContext(Config &) { - return std::vector<std::vector<long>>{{0}}; + torch::Tensor context; + return context; } void RandomNetworkImpl::registerEmbeddings() diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index 
d948386e6f2d7ffd376c8170b6a670f895a48948..2d6bd62164e13adb8b21547cd895490a0f173e59 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -11,6 +11,9 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def leftWindow = std::stoi(sm.str(1)); rightWindow = std::stoi(sm.str(2)); + if (leftWindow < 0 or rightWindow < 0) + util::myThrow(fmt::format("Invalid negative values for leftWindow({}) or rightWindow({})", leftWindow, rightWindow)); + auto subModuleType = sm.str(3); auto subModuleArguments = util::split(sm.str(4), ' '); @@ -54,27 +57,30 @@ std::size_t RawInputModuleImpl::getInputSize() return leftWindow + rightWindow + 1; } -void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void RawInputModuleImpl::addToContext(torch::Tensor & context, const Config & config) { - if (leftWindow < 0 or rightWindow < 0) - return; - std::string prefix = "LETTER"; - auto & dict = getDict(); - for (auto & contextElement : context) + + int insertIndex = 0; + for (int i = 0; i < leftWindow; i++) { - for (int i = 0; i < leftWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) - contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix)); - else - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); - - for (int i = 0; i <= rightWindow; i++) - if (config.hasCharacter(config.getCharacterIndex()+i)) - contextElement.push_back(dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix)); - else - contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr, prefix)); + if (config.hasCharacter(config.getCharacterIndex()-leftWindow+i)) + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()-leftWindow+i)), prefix); + else + 
context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); + + insertIndex++; + } + + for (int i = 0; i <= rightWindow; i++) + { + if (config.hasCharacter(config.getCharacterIndex()+i)) + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(fmt::format("{}", config.getLetter(config.getCharacterIndex()+i)), prefix); + else + context[firstInputIndex+insertIndex] = dict.getIndexOrInsert(Dict::nullValueStr, prefix); + + insertIndex++; } } diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 0c1de2e7f5f9a1dbe7003ac24cd02d474d399048..dcb78e1f42464871301c2556da000f61efb63b28 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -51,16 +51,15 @@ std::size_t SplitTransModuleImpl::getInputSize() return maxNbTrans; } -void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void SplitTransModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); auto & splitTransitions = config.getAppliableSplitTransitions(); - for (auto & contextElement : context) - for (int i = 0; i < maxNbTrans; i++) - if (i < (int)splitTransitions.size()) - contextElement.emplace_back(dict.getIndexOrInsert(splitTransitions[i]->getName(), "")); - else - contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr, "")); + for (int i = 0; i < maxNbTrans; i++) + if (i < (int)splitTransitions.size()) + context[firstInputIndex+i] = dict.getIndexOrInsert(splitTransitions[i]->getName(), ""); + else + context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, ""); } void SplitTransModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/StateNameModule.cpp b/torch_modules/src/StateNameModule.cpp index 0c642b947b78a69a64490ab4e2dc7f070b3277af..f3ac97753a08a2529a0708295dfd1cc21231f974 100644 --- a/torch_modules/src/StateNameModule.cpp +++ 
b/torch_modules/src/StateNameModule.cpp @@ -29,11 +29,10 @@ std::size_t StateNameModuleImpl::getInputSize() return 1; } -void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void StateNameModuleImpl::addToContext(torch::Tensor & context, const Config & config) { auto & dict = getDict(); - for (auto & contextElement : context) - contextElement.emplace_back(dict.getIndexOrInsert(config.getState(), "")); + context[firstInputIndex] = dict.getIndexOrInsert(config.getState(), ""); } void StateNameModuleImpl::registerEmbeddings() diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index 0452eb8db781b8e83a1e62069b88c790b1214678..7c846150c1e96af8f7886520592b8d4363e9c78d 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -56,7 +56,7 @@ std::size_t UppercaseRateModuleImpl::getInputSize() return focusedBuffer.size() + focusedStack.size(); } -void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config) +void UppercaseRateModuleImpl::addToContext(torch::Tensor & context, const Config & config) { std::vector<long> focusedIndexes; @@ -69,25 +69,24 @@ void UppercaseRateModuleImpl::addToContext(std::vector<std::vector<long>> & cont else focusedIndexes.emplace_back(-1); - for (auto & contextElement : context) + int insertIndex = 0; + for (auto index : focusedIndexes) { - for (auto index : focusedIndexes) + double res = -1.0; + if (index >= 0) { - double res = -1.0; - if (index >= 0) - { - auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get()); - int nbUpper = 0; - for (auto & letter : word) - if (util::isUppercase(letter)) - nbUpper++; - if (word.size() > 0) - res = 1.0*nbUpper/word.size(); - } - - contextElement.emplace_back(0); - std::memcpy(&contextElement.back(), &res, sizeof res); + auto word = util::splitAsUtf8(config.getAsFeature("FORM", index).get()); + int 
nbUpper = 0; + for (auto & letter : word) + if (util::isUppercase(letter)) + nbUpper++; + if (word.size() > 0) + res = 1.0*nbUpper/word.size(); } + + //TODO : direct assignment truncates the double to a long; the old code memcpy'd the bit pattern - check what forward() expects + context[firstInputIndex+insertIndex] = res; + insertIndex++; + } } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index dfa465e8a524134b55c7887892940fe0bc1a01cd..a5088d0eebd1dd5e60449632718aed378d7c893c 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -34,7 +34,7 @@ class Trainer int lastSavedIndex{0}; void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle); - void addContext(std::vector<std::vector<long>> & context); + void addContext(torch::Tensor & context); void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes); }; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 56ecc44cbfdd3fdf069856b4da3963665ae9c068..4143b2cb5a961f9db485dbd9891a2d35e569fc0c 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -65,7 +65,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: auto appliableTransitions = machine.getTransitionSet(config.getState()).getAppliableTransitions(config); config.setAppliableTransitions(appliableTransitions); - std::vector<std::vector<long>> context; + torch::Tensor context; try { @@ -92,8 +92,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto & classifier = *machine.getClassifier(config.getState()); - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device); - auto prediction = classifier.isRegression() ?
classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(neuralInput, config.getState()).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()->forward(context, config.getState()).squeeze(0) : torch::softmax(classifier.getNN()->forward(context, config.getState()).squeeze(0), 0); entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates; @@ -154,7 +153,7 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std: if (!exampleIsBanned) { - totalNbExamples += context.size(); + totalNbExamples += 1; if (totalNbExamples >= (int)safetyNbExamplesMax) util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); @@ -295,12 +294,11 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: classes.clear(); } -void Trainer::Examples::addContext(std::vector<std::vector<long>> & context) +void Trainer::Examples::addContext(torch::Tensor & context) { - for (auto & element : context) - contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone()); + contexts.emplace_back(context); - currentExampleIndex += context.size(); + currentExampleIndex += 1; } void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes)