Commit 0652f587 authored by Franck Dary's avatar Franck Dary
Browse files

Added dimmension reduction in Concat module

parent 5800a6f3
......@@ -101,6 +101,7 @@ int Classifier::getNbParameters() const
int nbParameters = 0;
for (auto & t : nn->parameters())
if (t.requires_grad())
nbParameters += torch::numel(t);
return nbParameters;
......
......@@ -9,10 +9,12 @@ class ConcatImpl : public MyModule
private :
int inputSize;
int outputSize;
torch::nn::Linear dimReduce{nullptr};
public :
ConcatImpl(int inputSize);
ConcatImpl(int inputSize, int outputSize);
torch::Tensor forward(torch::Tensor input);
int getOutputSize(int sequenceLength);
};
......
#include "Concat.hpp"
ConcatImpl::ConcatImpl(int inputSize) : inputSize(inputSize)
ConcatImpl::ConcatImpl(int inputSize, int outputSize) : inputSize(inputSize), outputSize(outputSize)
{
dimReduce = register_module("dimReduce", torch::nn::Linear(inputSize, outputSize));
}
torch::Tensor ConcatImpl::forward(torch::Tensor input)
{
return input.view({input.size(0), -1});
return dimReduce(input).view({input.size(0), -1});
}
int ConcatImpl::getOutputSize(int sequenceLength)
{
return sequenceLength * inputSize;
return sequenceLength * outputSize;
}
......@@ -42,7 +42,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else if (subModuleType == "Transformer")
myModule = register_module("myModule", Transformer(columns.size()*inSize, outSize, options));
else
......
......@@ -49,7 +49,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string &
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(columns.size()*inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -39,7 +39,7 @@ DepthLayerTreeEmbeddingModuleImpl::DepthLayerTreeEmbeddingModuleImpl(std::string
else if (subModuleType == "GRU")
depthModules.emplace_back(register_module(name, GRU(columns.size()*inSize, outSize, options)));
else if (subModuleType == "Concat")
depthModules.emplace_back(register_module(name, Concat(inSize)));
depthModules.emplace_back(register_module(name, Concat(inSize, outSize)));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
}
......
......@@ -39,7 +39,7 @@ DistanceModuleImpl::DistanceModuleImpl(std::string name, const std::string & def
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -35,7 +35,7 @@ FocusedColumnModuleImpl::FocusedColumnModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -29,7 +29,7 @@ HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -29,7 +29,7 @@ HistoryModuleImpl::HistoryModuleImpl(std::string name, const std::string & defin
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -35,7 +35,7 @@ NumericColumnModuleImpl::NumericColumnModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
myModule = register_module("myModule", Concat(1,1));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
......@@ -31,7 +31,7 @@ RawInputModuleImpl::RawInputModuleImpl(std::string name, const std::string & def
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else if (subModuleType == "Transformer")
myModule = register_module("myModule", Transformer(inSize, outSize, options));
else
......
......@@ -27,7 +27,7 @@ SplitTransModuleImpl::SplitTransModuleImpl(std::string name, int maxNbTrans, con
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
myModule = register_module("myModule", Concat(inSize, outSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
......
......@@ -31,7 +31,7 @@ UppercaseRateModuleImpl::UppercaseRateModuleImpl(std::string name, const std::st
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(1, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(1));
myModule = register_module("myModule", Concat(1, 1));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
......
......@@ -282,7 +282,7 @@ int MacaonTrain::main()
{
machine.resetClassifiers();
machine.trainMode(currentEpoch == 0);
fmt::print(stderr, "[{}] Model has {} parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
fmt::print(stderr, "[{}] Model has {} trainable parameters\n", util::getTime(), util::int2HumanStr(machine.getNbParameters()));
}
machine.resetOptimizers();
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment