Skip to content
Snippets Groups Projects
Commit b75be56d authored by Franck Dary's avatar Franck Dary
Browse files

maxNbEmbeddings is now part of NeuralNetwork

parent 07748169
No related branches found
No related tags found
No related merge requests found
...@@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold; int unknownValueThreshold;
std::vector<std::string> focusedColumns; std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements; std::vector<int> maxNbElements;
......
...@@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl ...@@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold; int unknownValueThreshold;
std::vector<std::string> focusedColumns; std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements; std::vector<int> maxNbElements;
......
...@@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module ...@@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module
protected : protected :
static constexpr int maxNbEmbeddings = 150000;
std::vector<std::string> columns{"FORM"}; std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1}; std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{}; std::vector<int> stackContext{};
......
...@@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector< ...@@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<
setStackContext(stackContext); setStackContext(stackContext);
setColumns({"FORM", "UPOS"}); setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize)); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3)); dropout = register_module("dropout", torch::nn::Dropout(0.3));
......
...@@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setStackContext({}); setStackContext({});
setColumns({"FORM", "UPOS"}); setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize)); linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true))); vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment