diff --git a/torch_modules/src/RTLSTMNetwork.cpp b/torch_modules/src/RTLSTMNetwork.cpp index ef3bfa92c7d2fc7eaa84f629622a5dfb3cbcdf1f..6cc8f70bab61d2924ed6c3c022e63303e5319fc0 100644 --- a/torch_modules/src/RTLSTMNetwork.cpp +++ b/torch_modules/src/RTLSTMNetwork.cpp @@ -2,12 +2,13 @@ RTLSTMNetworkImpl::RTLSTMNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements) { - constexpr int embeddingsSize = 100; + constexpr int embeddingsSize = 30; constexpr int lstmOutputSize = 500; constexpr int hiddenSize = 500; setLeftBorder(leftBorder); setRightBorder(rightBorder); setNbStackElements(nbStackElements); + setColumns({"FORM", "UPOS"}); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(lstmOutputSize, hiddenSize));