Commit b13669bd authored by Franck Dary

Added program arguments: scaleGrad and maxNorm

parent 397e390f
ContextModule.hpp
@@ -9,12 +9,13 @@
 #include "LSTM.hpp"
 #include "Concat.hpp"
 #include "Transformer.hpp"
+#include "WordEmbeddings.hpp"
 class ContextModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
ContextualModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "LSTM.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class ContextualModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<std::string> columns;
   std::vector<std::function<std::string(const std::string &)>> functions;
DepthLayerTreeEmbeddingModule.hpp
@@ -7,6 +7,7 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class DepthLayerTreeEmbeddingModuleImpl : public Submodule
 {
@@ -16,7 +17,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
   std::vector<std::string> columns;
   std::vector<int> focusedBuffer;
   std::vector<int> focusedStack;
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::vector<std::shared_ptr<MyModule>> depthModules;
   int inSize;
DistanceModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class DistanceModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> fromBuffer, fromStack;
   std::vector<int> toBuffer, toStack;
FocusedColumnModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class FocusedColumnModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   std::vector<int> focusedBuffer, focusedStack;
   std::string column;
HistoryModule.hpp
@@ -8,12 +8,13 @@
 #include "GRU.hpp"
 #include "CNN.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class HistoryModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbElements;
   int inSize;
RawInputModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class RawInputModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int leftWindow, rightWindow;
   int inSize;
SplitTransModule.hpp
@@ -7,12 +7,13 @@
 #include "LSTM.hpp"
 #include "GRU.hpp"
 #include "Concat.hpp"
+#include "WordEmbeddings.hpp"
 class SplitTransModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding wordEmbeddings{nullptr};
+  WordEmbeddings wordEmbeddings{nullptr};
   std::shared_ptr<MyModule> myModule{nullptr};
   int maxNbTrans;
   int inSize;
StateNameModule.hpp
@@ -6,12 +6,13 @@
 #include "MyModule.hpp"
 #include "LSTM.hpp"
 #include "GRU.hpp"
+#include "WordEmbeddings.hpp"
 class StateNameModuleImpl : public Submodule
 {
   private :
-  torch::nn::Embedding embeddings{nullptr};
+  WordEmbeddings embeddings{nullptr};
   int outSize;
   public :
Submodule.hpp
@@ -16,7 +16,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolder
   public :
   void setFirstInputIndex(std::size_t firstInputIndex);
-  void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path, std::string prefix);
+  void loadPretrainedW2vEmbeddings(torch::nn::Embedding embeddings, std::filesystem::path path, std::string prefix);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
   virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
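A note on the signature change above: `torch::nn::Embedding` is a libtorch `ModuleHolder`, a shared-pointer-like handle to the underlying module, so passing it by value copies only the handle and the callee still writes into the caller's weight tensor. That is also what makes the new `wordEmbeddings->get()` call sites further down valid even though `get()` returns the inner `torch::nn::Embedding` by value. A minimal sketch of these semantics (illustrative, not part of the commit):

```cpp
#include <iostream>
#include <torch/torch.h>

int main()
{
  torch::nn::Embedding a(torch::nn::EmbeddingOptions(10, 4));
  torch::nn::Embedding b = a;   // copies the handle, not the parameters
  b->weight.data().zero_();     // mutates the module shared by a and b
  std::cout << a->weight.sum().item<float>() << std::endl; // prints 0
}
```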
WordEmbeddings.hpp (new file)
+#ifndef WORDEMBEDDINGS__H
+#define WORDEMBEDDINGS__H
+#include "torch/torch.h"
+class WordEmbeddingsImpl : public torch::nn::Module
+{
+  private :
+  static bool scaleGradByFreq;
+  static float maxNorm;
+  private :
+  torch::nn::Embedding embeddings{nullptr};
+  public :
+  static void setScaleGradByFreq(bool scaleGradByFreq);
+  static void setMaxNorm(float maxNorm);
+  WordEmbeddingsImpl(std::size_t vocab, std::size_t dim);
+  torch::nn::Embedding get();
+  torch::Tensor forward(torch::Tensor input);
+};
+TORCH_MODULE(WordEmbeddings);
+#endif
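The matching definitions are not included in this diff; a plausible `WordEmbeddings.cpp` consistent with the header would route the two static settings into `torch::nn::EmbeddingOptions`. The defaults below and the convention that a non-positive `maxNorm` disables renormalization are assumptions, not taken from the commit:

```cpp
#include "WordEmbeddings.hpp"

// Assumed defaults: gradient scaling off, no max-norm renormalization.
bool WordEmbeddingsImpl::scaleGradByFreq = false;
float WordEmbeddingsImpl::maxNorm = 0.0;

void WordEmbeddingsImpl::setScaleGradByFreq(bool scaleGradByFreq)
{
  WordEmbeddingsImpl::scaleGradByFreq = scaleGradByFreq;
}

void WordEmbeddingsImpl::setMaxNorm(float maxNorm)
{
  WordEmbeddingsImpl::maxNorm = maxNorm;
}

WordEmbeddingsImpl::WordEmbeddingsImpl(std::size_t vocab, std::size_t dim)
{
  auto options = torch::nn::EmbeddingOptions(vocab, dim)
                   .scale_grad_by_freq(scaleGradByFreq);
  if (maxNorm > 0.0) // assumed: non-positive value means "no max norm"
    options.max_norm(maxNorm);
  embeddings = register_module("embeddings", torch::nn::Embedding(options));
}

torch::nn::Embedding WordEmbeddingsImpl::get()
{
  return embeddings;
}

torch::Tensor WordEmbeddingsImpl::forward(torch::Tensor input)
{
  return embeddings(input);
}
```

With a wrapper like this, every submodule that switches from `torch::nn::Embedding` to `WordEmbeddings` picks up the same global scaleGrad/maxNorm behavior without changing its own constructor arguments, which is why the rest of the diff is so mechanical.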
ContextModule.cpp
@@ -161,12 +161,12 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
 void ContextModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
ContextualModule.cpp
@@ -210,13 +210,13 @@ torch::Tensor ContextualModuleImpl::forward(torch::Tensor input)
 void ContextualModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
DepthLayerTreeEmbeddingModule.cpp
@@ -126,6 +126,6 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
DistanceModule.cpp
@@ -111,6 +111,6 @@ void DistanceModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void DistanceModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
FocusedColumnModule.cpp
@@ -156,12 +156,12 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void FocusedColumnModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
   auto pathes = util::split(w2vFiles.string(), ' ');
   for (auto & p : pathes)
   {
     auto splited = util::split(p, ',');
-    loadPretrainedW2vEmbeddings(wordEmbeddings, path / splited[1], splited[0]);
+    loadPretrainedW2vEmbeddings(wordEmbeddings->get(), path / splited[1], splited[0]);
   }
 }
HistoryModule.cpp
@@ -69,6 +69,6 @@ void HistoryModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void HistoryModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
RawInputModule.cpp
@@ -78,6 +78,6 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void RawInputModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
SplitTransModule.cpp
@@ -65,6 +65,6 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void SplitTransModuleImpl::registerEmbeddings()
 {
-  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
+  wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
 }
StateNameModule.cpp
@@ -38,6 +38,6 @@ void StateNameModuleImpl::addToContext(std::vector<std::vector<long>> & context, const Config & config)
 void StateNameModuleImpl::registerEmbeddings()
 {
-  embeddings = register_module("embeddings", torch::nn::Embedding(getDict().size(), outSize));
+  embeddings = register_module("embeddings", WordEmbeddings(getDict().size(), outSize));
 }
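The command-line side announced in the commit title (scaleGrad and maxNorm) does not appear in the hunks above; presumably the training entry point forwards the parsed values to the static setters before any module is registered, along these lines (the option handling and values here are illustrative, not from the commit):

```cpp
// Hypothetical wiring: apply the global embedding options first, then
// build the network so every WordEmbeddings instance picks them up.
WordEmbeddingsImpl::setScaleGradByFreq(true); // value of the scaleGrad argument
WordEmbeddingsImpl::setMaxNorm(5.0f);         // value of the maxNorm argument
// ... then construct the model, which calls registerEmbeddings() ...
```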