From b75be56d9e96f835fcd81dae4e1aefdaaf3f18a0 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 1 Apr 2020 14:05:15 +0200
Subject: [PATCH] maxNbEmbeddings is now part of NeuralNetwork

---
 torch_modules/include/CNNNetwork.hpp     | 2 --
 torch_modules/include/LSTMNetwork.hpp    | 2 --
 torch_modules/include/NeuralNetwork.hpp  | 2 ++
 torch_modules/src/ConcatWordsNetwork.cpp | 2 +-
 torch_modules/src/RLTNetwork.cpp         | 2 +-
 5 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index a036e76..6fb985a 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  static constexpr int maxNbEmbeddings = 50000;
-
   int unknownValueThreshold;
   std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp
index e276eb2..0e1ad1c 100644
--- a/torch_modules/include/LSTMNetwork.hpp
+++ b/torch_modules/include/LSTMNetwork.hpp
@@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
 {
   private :
 
-  static constexpr int maxNbEmbeddings = 50000;
-
   int unknownValueThreshold;
   std::vector<std::string> focusedColumns;
   std::vector<int> maxNbElements;
diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp
index 8ffa734..be1846b 100644
--- a/torch_modules/include/NeuralNetwork.hpp
+++ b/torch_modules/include/NeuralNetwork.hpp
@@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module
 
   protected : 
 
+  static constexpr int maxNbEmbeddings = 150000;
+
   std::vector<std::string> columns{"FORM"};
   std::vector<int> bufferContext{-3,-2,-1,0,1};
   std::vector<int> stackContext{};
diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp
index b03b849..2331d59 100644
--- a/torch_modules/src/ConcatWordsNetwork.cpp
+++ b/torch_modules/src/ConcatWordsNetwork.cpp
@@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<
   setStackContext(stackContext);
   setColumns({"FORM", "UPOS"});
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
   dropout = register_module("dropout", torch::nn::Dropout(0.3));
diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp
index a9a346f..e4f3fc2 100644
--- a/torch_modules/src/RLTNetwork.cpp
+++ b/torch_modules/src/RLTNetwork.cpp
@@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   setStackContext({});
   setColumns({"FORM", "UPOS"});
 
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
   linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
   vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
-- 
GitLab