diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index 4e3ded1fc8b7517234c247c6daadd5de3b196790..e2dd7c91cb9fd3fbb6be6191c17dbf4ca2bd2fec 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -101,7 +101,8 @@ int Classifier::getNbParameters() const int nbParameters = 0; for (auto & t : nn->parameters()) - nbParameters += torch::numel(t); + if (t.requires_grad()) + nbParameters += torch::numel(t); return nbParameters; } diff --git a/torch_modules/include/Concat.hpp b/torch_modules/include/Concat.hpp index 4c7de25154e933448394c15add2f6e052c1f85ff..b6134b70afaae7aa28557e1428df898887feac3e 100644 --- a/torch_modules/include/Concat.hpp +++ b/torch_modules/include/Concat.hpp @@ -9,10 +9,12 @@ class ConcatImpl : public MyModule private : int inputSize; + int outputSize; + torch::nn::Linear dimReduce{nullptr}; public : - ConcatImpl(int inputSize); + ConcatImpl(int inputSize, int outputSize); torch::Tensor forward(torch::Tensor input); int getOutputSize(int sequenceLength); }; diff --git a/torch_modules/src/Concat.cpp b/torch_modules/src/Concat.cpp index 09d99c64797133e7778a6dfaf055f3faab6a7b40..7ba200c3c5e322d35aeb8fb4870d95fdb0af0bd3 100644 --- a/torch_modules/src/Concat.cpp +++ b/torch_modules/src/Concat.cpp @@ -1,16 +1,17 @@ #include "Concat.hpp" -ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize) +ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize) { + dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize)); } torch::Tensor ConcatImpl::forward(torch::Tensor input) { - return input.view({input.size(0), -1}); + return dimReduce(input).view({input.size(0), -1}); } int ConcatImpl::getOutputSize(int sequenceLength) { - return sequenceLength * inputSize; + return sequenceLength * outputSize; } diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index b99f6ea34c60745d4b9f0773d628f69fb92f3839..b1c7fd2a1baf30b44415fb6e437a0ec099f469b4 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else if (subModuleType == "Transformer") myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options)); else diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 69925249a5c88ed2c09b6668b132caef815b02dd..bd825f7c1cd1e91c4c20ceca354bf269469a2860 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp index a60433e893f6930467e78aa6c14a5737b6552166..acc45d557d3d15d0f488db9b55c260f9f9544342 100644 --- a/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp +++ b/torch_modules/src/DepthLayerTreeEmbeddingModule.cpp @@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string else if (subModuleType == "GRU") depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options))); else if (subModuleType == "Concat") - depthModules.emplace_back(register_module(name, Concat(inSize))); + depthModules.emplace_back(register_module(name, Concat(inSize, outSize))); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } diff --git a/torch_modules/src/DistanceModule.cpp b/torch_modules/src/DistanceModule.cpp index a51eea090d8d24360be0aaf3a0ba307302b9b8a2..0aebe58f88d07d745011039ddae2c8181a9929a3 100644 --- a/torch_modules/src/DistanceModule.cpp +++ b/torch_modules/src/DistanceModule.cpp @@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/FocusedColumnModule.cpp b/torch_modules/src/FocusedColumnModule.cpp index 107e9561fa67489bffcf6b5091c9435289478190..62da3de4a5fc30af10f100add7b27954f9d7d99c 100644 --- a/torch_modules/src/FocusedColumnModule.cpp +++ b/torch_modules/src/FocusedColumnModule.cpp @@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/HistoryMineModule.cpp b/torch_modules/src/HistoryMineModule.cpp index 25bfcc1a22786453e0695512b0f1ba57b67e12d0..7d1c6f573639219ca09bb63b197abb4fb09ca1d7 100644 --- a/torch_modules/src/HistoryMineModule.cpp +++ b/torch_modules/src/HistoryMineModule.cpp @@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string else if (subModuleType == "CNN") myModule = register_module("myModule", CNN(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/HistoryModule.cpp b/torch_modules/src/HistoryModule.cpp index dddfdf730d89050643fb5194ad611cf1f25b8330..c8973643b4188b194a4b75a003a9459acb01ee6f 100644 --- a/torch_modules/src/HistoryModule.cpp +++ b/torch_modules/src/HistoryModule.cpp @@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin else if (subModuleType == "CNN") myModule = register_module("myModule", CNN(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/NumericColumnModule.cpp b/torch_modules/src/NumericColumnModule.cpp index a5001e7c8ae3619f60bd1150e99817ce2d5448f0..a666fc3ae6651d26c5484dc5aa0ea37dd83fedaf 100644 --- a/torch_modules/src/NumericColumnModule.cpp +++ b/torch_modules/src/NumericColumnModule.cpp @@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(1)); + myModule = register_module("myModule", Concat(1,1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} diff --git a/torch_modules/src/RawInputModule.cpp b/torch_modules/src/RawInputModule.cpp index d237485567a4ee8068466d696a869235e1c30381..c5dc9a55330e48802da176b9dd6f26a424ad6460 100644 --- a/torch_modules/src/RawInputModule.cpp +++ b/torch_modules/src/RawInputModule.cpp @@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else if (subModuleType == "Transformer") myModule = register_module("myModule", Transformer(inSize, outSize, options)); else diff --git a/torch_modules/src/SplitTransModule.cpp b/torch_modules/src/SplitTransModule.cpp index 5f361a00b3d6cc42d71d734ee598ddbbf546c394..ee3fa389977cd648d142b450ca43b6a93d358fa6 100644 --- a/torch_modules/src/SplitTransModule.cpp +++ b/torch_modules/src/SplitTransModule.cpp @@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(inSize, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(inSize)); + myModule = register_module("myModule", Concat(inSize, outSize)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); diff --git a/torch_modules/src/UppercaseRateModule.cpp b/torch_modules/src/UppercaseRateModule.cpp index ff6e0acbf565454f031e0f177216398031635cd9..8d86c74ddc56bb0a88ed0e19da05560dc2aef86f 100644 --- a/torch_modules/src/UppercaseRateModule.cpp +++ b/torch_modules/src/UppercaseRateModule.cpp @@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st else if (subModuleType == "GRU") myModule = register_module("myModule", GRU(1, outSize, options)); else if (subModuleType == "Concat") - myModule = register_module("myModule", Concat(1)); + myModule = register_module("myModule", Concat(1, 1)); else util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 98467837dd2e8abb128dd283d6f7f51265937b33..cef2beaa7ae21b88af745b2264c85fd9a73c4c74 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -282,7 +282,7 @@ int MacaonTrain::main() { machine.resetClassifiers(); machine.trainMode(currentEpoch == 0); - fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters())); + fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters())); } machine.resetOptimizers();