From f5a30e7111a8585c858485f03ad1a7a0bc96a1d0 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Mon, 24 Feb 2020 19:02:14 +0100 Subject: [PATCH] Added function so set columns of NeuralNetwork --- torch_modules/include/NeuralNetwork.hpp | 3 ++- torch_modules/src/ConcatWordsNetwork.cpp | 1 + torch_modules/src/NeuralNetwork.cpp | 5 +++++ torch_modules/src/OneWordNetwork.cpp | 1 + 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 5299d2d..5c8493e 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -12,7 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module int leftBorder{5}; int rightBorder{5}; int nbStackElements{2}; - std::vector<std::string> columns{"FORM", "UPOS"}; + std::vector<std::string> columns{"FORM"}; protected : @@ -25,6 +25,7 @@ class NeuralNetworkImpl : public torch::nn::Module virtual torch::Tensor forward(torch::Tensor input) = 0; std::vector<long> extractContext(Config & config, Dict & dict) const; int getContextSize() const; + void setColumns(const std::vector<std::string> & columns); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index fd9f2b8..81ee625 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -6,6 +6,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(nbStackElements); + setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index e39729b..cdb8ad3 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -56,3 +56,8 @@ void NeuralNetworkImpl::setNbStackElements(int nbStackElements) this->nbStackElements = nbStackElements; } +void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns) +{ + this->columns = columns; +} + diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp index 6e3c934..d2d7966 100644 --- a/torch_modules/src/OneWordNetwork.cpp +++ b/torch_modules/src/OneWordNetwork.cpp @@ -19,6 +19,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(0); + setColumns({"FORM", "UPOS"}); } torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input) -- GitLab