Skip to content
Snippets Groups Projects
Commit b75be56d authored by Franck Dary's avatar Franck Dary
Browse files

maxNbEmbeddings is now part of NeuralNetwork

parent 07748169
No related branches found
No related tags found
No related merge requests found
......@@ -8,8 +8,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
......
......@@ -7,8 +7,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
......
......@@ -13,6 +13,8 @@ class NeuralNetworkImpl : public torch::nn::Module
protected :
static constexpr int maxNbEmbeddings = 150000;
std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{};
......
......@@ -9,7 +9,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<
setStackContext(stackContext);
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3));
......
......@@ -14,7 +14,7 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setStackContext({});
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(treeEmbeddingsSize*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
vectorBiLSTM = register_module("vector_lstm", torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize*columns.size(), lstmOutputSize).batch_first(true).bidirectional(true)));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment