diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 7d2a4fc6f7a03feb098147e53e69e2bbd898f8bd..5ecccaecb697c894da913cec346136bd0ef5f2da 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -35,7 +35,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) auto rawLetters = embeddings.narrow(1, 0, leftWindowRawInput+rightWindowRawInput+1); - auto context = embeddings.narrow(1, rawLetters.size(0), columns.size()*(1+leftBorder+rightBorder)); + auto context = embeddings.narrow(1, rawLetters.size(1), columns.size()*(1+leftBorder+rightBorder)); context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}); auto elementsEmbeddings = embeddings.narrow(1, rawLetters.size(1)+context.size(1), input.size(1)-(rawLetters.size(1)+context.size(1)));