diff --git a/torch_modules/src/ConcatWordsNetwork.cpp b/torch_modules/src/ConcatWordsNetwork.cpp index 9d7d0c44db06c3cea838821e0323a5d9ec4307fb..b72694c4fa296a2215f8c3d65991977a2e69e6dc 100644 --- a/torch_modules/src/ConcatWordsNetwork.cpp +++ b/torch_modules/src/ConcatWordsNetwork.cpp @@ -7,7 +7,7 @@ ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, in setRightBorder(rightBorder); setNbStackElements(nbStackElements); - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize).sparse(true))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500)); linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs)); } diff --git a/torch_modules/src/OneWordNetwork.cpp b/torch_modules/src/OneWordNetwork.cpp index c2a11db6c3fbc8e14545a0c04ce4e17170b8b18e..6e3c934947df108bbf9a59cf45d54b39e3fd23e8 100644 --- a/torch_modules/src/OneWordNetwork.cpp +++ b/torch_modules/src/OneWordNetwork.cpp @@ -4,7 +4,7 @@ OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex) { constexpr int embeddingsSize = 30; - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true))); + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize))); linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs)); int leftBorder = 0;