Skip to content
Snippets Groups Projects
Commit f5a30e71 authored by Franck Dary's avatar Franck Dary
Browse files

Added function so set columns of NeuralNetwork

parent 4bebd4a8
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module ...@@ -12,7 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module
int leftBorder{5}; int leftBorder{5};
int rightBorder{5}; int rightBorder{5};
int nbStackElements{2}; int nbStackElements{2};
std::vector<std::string> columns{"FORM", "UPOS"}; std::vector<std::string> columns{"FORM"};
protected : protected :
...@@ -25,6 +25,7 @@ class NeuralNetworkImpl : public torch::nn::Module ...@@ -25,6 +25,7 @@ class NeuralNetworkImpl : public torch::nn::Module
virtual torch::Tensor forward(torch::Tensor input) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0;
std::vector<long> extractContext(Config & config, Dict & dict) const; std::vector<long> extractContext(Config & config, Dict & dict) const;
int getContextSize() const; int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
}; };
TORCH_MODULE(NeuralNetwork); TORCH_MODULE(NeuralNetwork);
......
...@@ -6,6 +6,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in ...@@ -6,6 +6,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in
setLeftBorder(leftBorder); setLeftBorder(leftBorder);
setRightBorder(rightBorder); setRightBorder(rightBorder);
setNbStackElements(nbStackElements); setNbStackElements(nbStackElements);
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
......
...@@ -56,3 +56,8 @@ void NeuralNetworkImpl::setNbStackElements(int nbStackElements) ...@@ -56,3 +56,8 @@ void NeuralNetworkImpl::setNbStackElements(int nbStackElements)
this->nbStackElements = nbStackElements; this->nbStackElements = nbStackElements;
} }
void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
{
this->columns = columns;
}
...@@ -19,6 +19,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) ...@@ -19,6 +19,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
setLeftBorder(leftBorder); setLeftBorder(leftBorder);
setRightBorder(rightBorder); setRightBorder(rightBorder);
setNbStackElements(0); setNbStackElements(0);
setColumns({"FORM", "UPOS"});
} }
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input) torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment