Commit 0652f587 authored by Franck Dary's avatar Franck Dary
Browse files

Added dimmension reduction in Concat module

parent 5800a6f3
...@@ -101,7 +101,8 @@ int Classifier::getNbParameters() const ...@@ -101,7 +101,8 @@ int Classifier::getNbParameters() const
int nbParameters = 0; int nbParameters = 0;
for (auto & t : nn->parameters()) for (auto & t : nn->parameters())
nbParameters += torch::numel(t); if (t.requires_grad())
nbParameters += torch::numel(t);
return nbParameters; return nbParameters;
} }
......
...@@ -9,10 +9,12 @@ class ConcatImpl : public MyModule ...@@ -9,10 +9,12 @@ class ConcatImpl : public MyModule
private : private :
int inputSize; int inputSize;
int outputSize;
torch::nn::Linear dimReduce{nullptr};
public : public :
ConcatImpl(int inputSize); ConcatImpl(int inputSize, int outputSize);
torch::Tensor forward(torch::Tensor input); torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength); int getOutputSize(int sequenceLength);
}; };
......
#include "Concat.hpp" #include "Concat.hpp"
ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize) ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize)
{ {
dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
} }
torch::Tensor ConcatImpl::forward(torch::Tensor input) torch::Tensor ConcatImpl::forward(torch::Tensor input)
{ {
return input.view({input.size(0), -1}); return dimReduce(input).view({input.size(0), -1});
} }
int ConcatImpl::getOutputSize(int sequenceLength) int ConcatImpl::getOutputSize(int sequenceLength)
{ {
return sequenceLength * inputSize; return sequenceLength * outputSize;
} }
...@@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin ...@@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else if (subModuleType == "Transformer") else if (subModuleType == "Transformer")
myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options)); myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
else else
......
...@@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & ...@@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options)); myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string ...@@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options))); depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
depthModules.emplace_back(register_module(name, Concat(inSize))); depthModules.emplace_back(register_module(name, Concat(inSize, outSize)));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} }
......
...@@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def ...@@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st ...@@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string ...@@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string
else if (subModuleType == "CNN") else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options)); myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin ...@@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
else if (subModuleType == "CNN") else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options)); myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st ...@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options)); myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1)); myModule = register_module("myModule", Concat(1,1));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
...@@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def ...@@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else if (subModuleType == "Transformer") else if (subModuleType == "Transformer")
myModule = register_module("myModule", Transformer(inSize, outSize, options)); myModule = register_module("myModule", Transformer(inSize, outSize, options));
else else
......
...@@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con ...@@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options)); myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize)); myModule = register_module("myModule", Concat(inSize, outSize));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
...@@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st ...@@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU") else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options)); myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat") else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1)); myModule = register_module("myModule", Concat(1, 1));
else else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType)); util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));} } catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
...@@ -282,7 +282,7 @@ int MacaonTrain::main() ...@@ -282,7 +282,7 @@ int MacaonTrain::main()
{ {
machine.resetClassifiers(); machine.resetClassifiers();
machine.trainMode(currentEpoch == 0); machine.trainMode(currentEpoch == 0);
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters())); fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
} }
machine.resetOptimizers(); machine.resetOptimizers();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment