From b75be56d9e96f835fcd81dae4e1aefdaaf3f18a0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 1 Apr 2020 14:05:15 +0200 Subject: [PATCH] maxNbEmbeddings is now part of NeuralNetwork --- torch_modules/include/CNNNetwork.hpp | 2 -- torch_modules/include/LSTMNetwork.hpp | 2 -- torch_modules/include/NeuralNetwork.hpp | 2 ++ torch_modules/src/ConcatWordsNetwork.cpp | 2 +- torch_modules/src/RLTNetwork.cpp | 2 +- 5 files changed, 4 insertions(+), 6 deletions(-) diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp index a036e76..6fb985a 100644 --- a/torch_modules/include/CNNNetwork.hpp +++ b/torch_modules/include/CNNNetwork.hpp @@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl { private : - static constexpr int maxNbEmbeddings = 50000; - int unknownValueThreshold; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; diff --git a/torch_modules/include/LSTMNetwork.hpp b/torch_modules/include/LSTMNetwork.hpp index e276eb2..0e1ad1c 100644 --- a/torch_modules/include/LSTMNetwork.hpp +++ b/torch_modules/include/LSTMNetwork.hpp @@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl { private : - static constexpr int maxNbEmbeddings = 50000; - int unknownValueThreshold; std::vector<std::string> focusedColumns; std::vector<int> maxNbElements; diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 8ffa734..be1846b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module protected : + static constexpr int maxNbEmbeddings = 150000; + std::vector<std::string> columns{"FORM"}; std::vector<int> bufferContext{-3,-2,-1,0,1}; std::vector<int> stackContext{}; diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index b03b849..2331d59 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector< setStackContext(stackContext); setColumns({"FORM", "UPOS"}); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); dropout = register_module("dropout", torch::nn::Dropout(0.3)); diff --git a/torch_modules/src/RLTNetwork.cpp b/torch_modules/src/RLTNetwork.cpp index a9a346f..e4f3fc2 100644 --- a/torch_modules/src/RLTNetwork.cpp +++ b/torch_modules/src/RLTNetwork.cpp @@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i setStackContext({}); setColumns({"FORM", "UPOS"}); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true))); -- GitLab