From bfd53a1705e2c1f9b70d146fd55a0e30d3137671 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 24 Apr 2020 19:00:18 +0200 Subject: [PATCH] Updated syntax for pytorch 1.5 --- reading_machine/src/Classifier.cpp | 2 +- torch_modules/src/LSTM.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index d8418c9..d9ead9a 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -101,7 +101,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) if (splited.size() != 6 or (splited.back() != "false" and splited.back() != "true")) util::myThrow(expected); - optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").beta1(std::stof(splited[1])).beta2(std::stof(splited[2])).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4])))); + optimizer.reset(new torch::optim::Adam(getNN()->parameters(), torch::optim::AdamOptions(std::stof(splited[0])).amsgrad(splited.back() == "true").betas({std::stof(splited[1]),std::stof(splited[2])}).eps(std::stof(splited[3])).weight_decay(std::stof(splited[4])))); } else util::myThrow(expected); diff --git a/torch_modules/src/LSTM.cpp b/torch_modules/src/LSTM.cpp index b8f8e7f..af89a3d 100644 --- a/torch_modules/src/LSTM.cpp +++ b/torch_modules/src/LSTM.cpp @@ -5,7 +5,7 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputA auto lstmOptions = torch::nn::LSTMOptions(inputSize, outputSize) .batch_first(std::get<0>(options)) .bidirectional(std::get<1>(options)) - .layers(std::get<2>(options)) + .num_layers(std::get<2>(options)) .dropout(std::get<3>(options)); lstm = register_module("lstm", torch::nn::LSTM(lstmOptions)); @@ -13,7 +13,7 @@ LSTMImpl::LSTMImpl(int inputSize, int outputSize, LSTMOptions options) : outputA torch::Tensor LSTMImpl::forward(torch::Tensor input) { - auto lstmOut = lstm(input).output; + auto lstmOut = std::get<0>(lstm(input)); if (outputAll) return lstmOut.reshape({lstmOut.size(0), -1}); -- GitLab