From 05acae819cc0a894e32231a9cb9e6ae79b5f5ce2 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 7 Feb 2020 14:35:31 +0100 Subject: [PATCH] Fixed TODO --- common/src/util.cpp | 1 - decoder/src/Decoder.cpp | 5 +---- trainer/src/Trainer.cpp | 3 +-- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/common/src/util.cpp b/common/src/util.cpp index 1e25018..379056f 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -153,7 +153,6 @@ bool util::doIfNameMatch(const std::regex & reg, std::string_view name, const st return true; } -//TODO : test this std::string util::strip(const std::string & s) { std::string striped; diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 3191cbf..eea1ef7 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -17,10 +17,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize) auto context = config.extractContext(5,5,machine.getDict(config.getState())); machine.getDict(config.getState()).setState(dictState); - //TODO : check if clone is mandatory - auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(); - //TODO : check if NoGradGuard does anything - torch::NoGradGuard guard; + auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong); auto prediction = machine.getClassifier()->getNN()(neuralInput); int chosenTransition = -1; diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 059da2b..51d1a3b 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -21,9 +21,8 @@ void Trainer::createDataset(SubConfig & config) util::myThrow("No transition appliable !"); } - //TODO : check if clone is mandatory auto context = config.extractContext(5,5,machine.getDict(config.getState())); - contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); + contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, at::kLong); -- GitLab