Skip to content
Snippets Groups Projects
Commit 05acae81 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed TODO

parent 15787f04
No related branches found
No related tags found
No related merge requests found
......@@ -153,7 +153,6 @@ bool util::doIfNameMatch(const std::regex & reg, std::string_view name, const st
return true;
}
//TODO : test this
std::string util::strip(const std::string & s)
{
std::string striped;
......
......@@ -17,10 +17,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState);
//TODO : check if clone is mandatory
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone();
//TODO : check if NoGradGuard does anything
torch::NoGradGuard guard;
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong);
auto prediction = machine.getClassifier()->getNN()(neuralInput);
int chosenTransition = -1;
......
......@@ -21,9 +21,8 @@ void Trainer::createDataset(SubConfig & config)
util::myThrow("No transition appliable !");
}
//TODO : check if clone is mandatory
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, at::kLong);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment