Commit b13669bd authored by Franck Dary

Added program arguments : scaleGrad and maxNorm

parent 397e390f
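
This change replaces the raw torch::nn::Embedding members of the feature submodules with a new wrapper module, WordEmbeddings, so that the two new program arguments (scaleGrad and maxNorm) can configure all embedding tables in one place: the headers switch to the new type, WordEmbeddings.hpp is added, and each registerEmbeddings() now constructs a WordEmbeddings.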
ContextModule.hpp
@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
ContextualModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class ContextualModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
DepthLayerTreeEmbeddingModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;
DistanceModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class DistanceModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;
FocusedColumnModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class FocusedColumnModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
HistoryModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class HistoryModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;
RawInputModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class RawInputModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;
SplitTransModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 
 class SplitTransModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;
StateNameModule.hpp
@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"
 
 class StateNameModuleImpl : public Submodule
 {
   private :
 
-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;
 
   public :
Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
   public :
 
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
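(Side note on the loadPretrainedW2vEmbeddings change: torch::nn::Embedding is a libtorch module holder, i.e. a thin shared_ptr wrapper generated by TORCH_MODULE, so passing it by value is cheap and still refers to the same underlying module; unlike the old non-const reference, it also accepts the temporary returned by wordEmbeddings->get() in the .cpp hunks below.)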
New file: WordEmbeddings.hpp
#ifndef WORDEMBEDDINGS__H
#define WORDEMBEDDINGS__H

#include "torch/torch.h"

class WordEmbeddingsImpl : public torch::nn::Module
{
  private :

  static bool scaleGradByFreq;
  static float maxNorm;

  private :

  torch::nn::Embedding embeddings{nullptr};

  public :

  static void setScaleGradByFreq(bool scaleGradByFreq);
  static void setMaxNorm(float maxNorm);
  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
  torch::nn::Embedding get();
  torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(WordEmbeddings);

#endif
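
The matching WordEmbeddings.cpp is not shown in this commit view. The sketch below is a plausible reconstruction, not the author's code: it assumes the statics default to "off", treats maxNorm <= 0 as "no renormalization", and forwards both values to torch::nn::EmbeddingOptions, whose scale_grad_by_freq and max_norm setters are part of the libtorch API.

// Hypothetical WordEmbeddings.cpp: a minimal sketch of the declared members,
// under the assumptions stated above.
#include "WordEmbeddings.hpp"

bool WordEmbeddingsImpl::scaleGradByFreq = false;   // assumed default: off
float WordEmbeddingsImpl::maxNorm = -1.0;           // assumed: <= 0 means no max norm

void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
{
  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
}

void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
{
  WordEmbeddingsImpl::maxNorm = maxNorm;
}

WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  // scale_grad_by_freq scales gradients by the inverse frequency of each
  // word in the batch; max_norm renormalizes any embedding vector whose
  // norm exceeds the given value.
  auto options = torch::nn::EmbeddingOptions(vocab, dim)
                   .scale_grad_by_freq(scaleGradByFreq);
  if (maxNorm > 0.0)
    options = options.max_norm(maxNorm);
  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}

torch::nn::Embedding WordEmbeddingsImpl::get()
{
  return embeddings;
}

torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  return embeddings(input);
}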
ContextModule.cpp
@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
ContextualModule.cpp
@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
DistanceModule.cpp
@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
FocusedColumnModule.cpp
@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
RawInputModule.cpp
@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
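
Nothing in this diff shows how the two program arguments reach the statics. A hypothetical call site (the flag names are assumptions taken from the commit title) would have to run before any registerEmbeddings() call, because the statics are read when each WordEmbeddings is constructed:

// Hypothetical wiring after command-line parsing; --scaleGrad and --maxNorm
// are assumed flag names, not confirmed by this diff.
WordEmbeddingsImpl::setScaleGradByFreq(true); // e.g. from a --scaleGrad flag
WordEmbeddingsImpl::setMaxNorm(5.0);          // e.g. from a --maxNorm value
// Every WordEmbeddings registered afterwards picks these settings up.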