From 0652f5872625aadab909bc40d417ee7cf9195fe7 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 10 Oct 2021 14:16:04 +0200 Subject: [PATCH] Added dimension reduction in Concat module --- reading_machine/src/Classifier.cpp | 3 ++- torch_modules/include/Concat.hpp | 4 +++- torch_modules/src/Concat.cpp | 7 ++++--- torch_modules/src/ContextModule.cpp | 2 +- torch_modules/src/ContextualModule.cpp | 2 +- torch_modules/src/DepthLayerTreeEmbeddingModule.cpp | 2 +- torch_modules/src/DistanceModule.cpp | 2 +- torch_modules/src/FocusedColumnModule.cpp | 2 +- torch_modules/src/HistoryMineModule.cpp | 2 +- torch_modules/src/HistoryModule.cpp | 2 +- torch_modules/src/NumericColumnModule.cpp | 2 +- torch_modules/src/RawInputModule.cpp | 2 +- torch_modules/src/SplitTransModule.cpp | 2 +- torch_modules/src/UppercaseRateModule.cpp | 2 +- trainer/src/MacaonTrain.cpp | 2 +- 15 files changed, 21 insertions(+), 17 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 4e3ded1..e2dd7c9 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -101,7 +101,8 @@ int Classifier::getNbParameters() const int nbParameters = 0; for (auto & t : nn->parameters()) - nbParameters += torch::numel(t); + if (t.requires_grad()) + nbParameters += torch::numel(t); return nbParameters; } diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp index 4c7de25..b6134b7 100644 --- a/torch_modules/include/Concat.hpp +++ b/torch_modules/include/Concat.hpp @@ -9,10 +9,12 @@ class ConcatImpl : public MyModule private : int inputSize; + int outputSize; + torch::nn::Linear dimReduce{nullptr}; public : - ConcatImpl(int inputSize); + ConcatImpl(int inputSize, int outputSize); torch::Tensor forward(torch::Tensor input); int getOutputSize(int sequenceLength); }; diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp index 09d99c6..7ba200c 100644 --- 
a/torch_modules/src/Concat.cpp +++ b/torch_modules/src/Concat.cpp @@ -1,16 +1,17 @@ #include "Concat.hpp" -ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize) +ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize) { + dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); } torch::Tensor ConcatImpl::forward(torch::Tensor input) { - return input.view({input.size(0), -1}); + return dimReduce(input).view({input.size(0), -1}); } int ConcatImpl::getOutputSize(int sequenceLength) { - return sequenceLength * inputSize; + return sequenceLength * outputSize; } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index b99f6ea..b1c7fd2 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else if (subModuleType == "Transformer") myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options)); else diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 6992524..bd825f7 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown 
sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index a60433e..acc45d5 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string else if (subModuleType == "GRU") depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options))); else if (subModuleType == "Concat") - depthModules.emplace_back(register_module(name, Concat(inSize))); + depthModules.emplace_back(register_module(name, Concat(inSize, outSize))); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index a51eea0..0aebe58 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 107e956..62da3de 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = 
register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp index 25bfcc1..7d1c6f5 100644 --- a/torch_modules/src/HistoryMineModule.cpp +++ b/torch_modules/src/HistoryMineModule.cpp @@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string else if (subModuleType == "CNN") myModule = register_module("myModule", CNN(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index dddfdf7..c897364 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin else if (subModuleType == "CNN") myModule = register_module("myModule", CNN(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index a5001e7..a666fc3 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(1)); + myModule = register_module("myModule", 
Concat(1,1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index d237485..c5dc9a5 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else if (subModuleType == "Transformer") myModule = register_module("myModule", Transformer(inSize, outSize, options)); else diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 5f361a0..ee3fa38 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index ff6e0ac..8d86c74 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - 
myModule = register_module("myModule", Concat(1)); + myModule = register_module("myModule", Concat(1, 1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 9846783..cef2bea 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -282,7 +282,7 @@ int MacaonTrain::main() { machine.resetClassifiers(); machine.trainMode(currentEpoch == 0); - fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters())); + fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters())); } machine.resetOptimizers(); -- GitLab