From 05acae819cc0a894e32231a9cb9e6ae79b5f5ce2 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Fri, 7 Feb 2020 14:35:31 +0100
Subject: [PATCH] Fixed TODO

---
 common/src/util.cpp     | 1 -
 decoder/src/Decoder.cpp | 5 +----
 trainer/src/Trainer.cpp | 3 +--
 3 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/common/src/util.cpp b/common/src/util.cpp
index 1e25018..379056f 100644
--- a/common/src/util.cpp
+++ b/common/src/util.cpp
@@ -153,7 +153,6 @@ bool util::doIfNameMatch(const std::regex & reg, std::string_view name, const st
   return true;
 }
 
-//TODO : test this
 std::string util::strip(const std::string & s)
 {
   std::string striped;
diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp
index 3191cbf..eea1ef7 100644
--- a/decoder/src/Decoder.cpp
+++ b/decoder/src/Decoder.cpp
@@ -17,10 +17,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
     auto context = config.extractContext(5,5,machine.getDict(config.getState()));
     machine.getDict(config.getState()).setState(dictState);
 
-    //TODO : check if clone is mandatory
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone();
-    //TODO : check if NoGradGuard does anything
-    torch::NoGradGuard guard;
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
     auto prediction = machine.getClassifier()->getNN()(neuralInput);
 
     int chosenTransition = -1;
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 059da2b..51d1a3b 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -21,9 +21,8 @@ void Trainer::createDataset(SubConfig & config)
       util::myThrow("No transition appliable !");
     }
 
-    //TODO : check if clone is mandatory
     auto context = config.extractContext(5,5,machine.getDict(config.getState()));
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+    contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
     auto gold = torch::zeros(1, at::kLong);
-- 
GitLab