diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 164de44db0df1b8fca52ac74e1b2ee4f015ec719..c0a717c14ed245a84ea908ccc95685c5c6260a41 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -29,7 +29,7 @@ ContextualModuleImpl::ContextualModuleImpl(std::string name, const std::string & .bidirectional(std::stoi(subModuleArguments[0])) .num_layers(std::stoi(subModuleArguments[1])) .dropout(std::stof(subModuleArguments[2])) - .complete(std::stoi(subModuleArguments[3])); + .complete(true); for (auto & target : util::split(sm.str(8), ' ')) { diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp index d84af461c789d1f295640d26bfddd18208f1c89f..c394321a7aec86cc2f9e3bdc5c3cd0d926b8c967 100644 --- a/torch_modules/src/LSTM.cpp +++ b/torch_modules/src/LSTM.cpp @@ -13,7 +13,12 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, ModuleOptions options) : outpu torch::Tensor LSTMImpl::forward(torch::Tensor input) { - return std::get<0>(lstm(input)); + auto res = std::get<0>(lstm(input)); + + if (outputAll) + return res; + + return torch::cat({torch::narrow(res, 1, 0, 1), torch::narrow(res, 1, res.size(1)-1, 1)}, 1); } int LSTMImpl::getOutputSize(int sequenceLength)