diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index 5299d2d6b1b295c7000d60dfb4596a636aaabef3..5c8493e562ed9ea2bf50df7bb6fd5ce2983a391d 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -12,7 +12,7 @@ class NeuralNetworkImpl : public torch::nn::Module int leftBorder{5}; int rightBorder{5}; int nbStackElements{2}; - std::vector<std::string> columns{"FORM", "UPOS"}; + std::vector<std::string> columns{"FORM"}; protected : @@ -25,6 +25,7 @@ class NeuralNetworkImpl : public torch::nn::Module virtual torch::Tensor forward(torch::Tensor input) = 0; std::vector<long> extractContext(Config & config, Dict & dict) const; int getContextSize() const; + void setColumns(const std::vector<std::string> & columns); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index fd9f2b8537a6685bbca24e2d664f570160c4dfe0..81ee6252f29ac1b1f676f25557cf64da3a413331 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -6,6 +6,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(nbStackElements); + setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index e39729bb5481b8029c9f3b98751f9f574dcf8e0d..cdb8ad3426050ddab9581713da34d42fe4394ae6 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -56,3 +56,8 @@ void NeuralNetworkImpl::setNbStackElements(int nbStackElements) this->nbStackElements = nbStackElements; } +void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns) +{ + this->columns = columns; +} + diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp index 6e3c934947df108bbf9a59cf45d54b39e3fd23e8..d2d796693c88e2c019c0067b0801d3a69a0d0d2c 100644 --- a/torch_modules/src/OneWordNetwork.cpp +++ b/torch_modules/src/OneWordNetwork.cpp @@ -19,6 +19,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(0); + setColumns({"FORM", "UPOS"}); } torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)