diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index a036e7622940d5eb22f2c8f820df0b0063f19d0c..6fb985ac482f88856cc8b2fa6bc07f4971e79546 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl { private : - static constexpr int maxNbEmbeddings = 50000; - int unknownValueThreshold; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index e276eb23f786bc3e80b7004cb67f893b347fe5cf..0e1ad1c91bc17cecffa3d47a218837f4c2506600 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl { private : - static constexpr int maxNbEmbeddings = 50000; - int unknownValueThreshold; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 8ffa7349d35b2b50816f5574a13a0e6c096c2e9c..be1846bacc0b356b5b4eed9aa67e2caf46f2a7ad 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module protected : + static constexpr int maxNbEmbeddings = 150000; + std::vector<std::string> columns{"FORM"}; std::vector<int> bufferContext{-3,-2,-1,0,1}; std::vector<int> stackContext{}; diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index b03b8493b0af20679e70cea6cd73af59e4658588..2331d59a61e1f8ee4bd7679bc009125c61e70775 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector< setStackContext(stackContext); setColumns({"FORM", "UPOS"}); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); dropout = register_module("dropout", torch::nn::Dropout(0.3)); diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp index a9a346f635a2f35f10aea8ff6743a320bbc537f6..e4f3fc215aef10c1cfbaa05371b094e980adf6cc 100644 --- a/torch_modules/src/RLTNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setStackContext({}); setColumns({"FORM", "UPOS"}); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));