Commit 36a24359 authored by Franck Dary

Revert "NE COMPILE PAS ENCORE : tentative de faire du minbatching"

This reverts commit 1cab6cf4.
parent 1cab6cf4
@@ -199,41 +199,6 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd)
void MLP::update(FeatureModel::FeatureDescription & fd, int gold)
{
  static std::vector<FeatureModel::FeatureDescription> batch;
  static std::vector<unsigned> golds;
  batch.emplace_back(fd);
  golds.emplace_back(gold);
  if (batch.size() >= 500)
  {
    dynet::ComputationGraph cg;
    std::vector<dynet::Expression> inputs;
    for (auto & it : batch)
    {
      std::vector<dynet::Expression> expressions;
      for (auto & featValue : it.values)
        expressions.emplace_back(featValue2Expression(cg, featValue));
      dynet::Expression input = dynet::concatenate(expressions);
      inputs.emplace_back(input);
    }
    // Merge the per-example inputs into one expression with a batch dimension.
    dynet::Expression batchedInput = dynet::concatenate_to_batch(inputs);
    dynet::Expression output = run(cg, batchedInput);
    // Batched neglogsoftmax: one loss per batch element; backward() needs a scalar, so sum over the batch.
    dynet::Expression batchedLoss = dynet::sum_batches(dynet::pickneglogsoftmax(output, golds));
    cg.forward(batchedLoss);
    cg.backward(batchedLoss);
    trainer->update();
    batch.clear();
    golds.clear();
  }
  /*
  dynet::ComputationGraph cg;
  std::vector<dynet::Expression> expressions;
@@ -247,7 +212,6 @@ void MLP::update(FeatureModel::FeatureDescription & fd, int gold)
  cg.backward(loss);
  trainer->update();
  */
}
dynet::DynetParams & MLP::getDefaultParams()
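
For reference, below is a self-contained sketch of a minibatched update written against the DyNet C++ API, in the spirit of the reverted attempt. The dimensions, parameter names, and toy data are illustrative assumptions, not taken from this repository. The key points: per-example inputs are merged with concatenate_to_batch, the batched pickneglogsoftmax overload takes a std::vector<unsigned> of gold indices directly (passing a dynet::Expression there, as the reverted code did, is the likely source of the compile error), and sum_batches reduces the per-example losses to the scalar that backward() requires.

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <vector>

int main(int argc, char ** argv)
{
  dynet::initialize(argc, argv);

  // Toy dimensions and a single affine layer (assumed, for illustration only).
  const unsigned inputDim = 4;
  const unsigned nbClasses = 3;
  dynet::ParameterCollection model;
  dynet::SimpleSGDTrainer trainer(model);
  dynet::Parameter pW = model.add_parameters({nbClasses, inputDim});
  dynet::Parameter pb = model.add_parameters({nbClasses});

  // Toy minibatch: two input vectors and their gold classes.
  std::vector<std::vector<float>> examples = {{1,0,0,0}, {0,1,0,0}};
  std::vector<unsigned> golds = {0, 2};

  dynet::ComputationGraph cg;
  dynet::Expression W = dynet::parameter(cg, pW);
  dynet::Expression b = dynet::parameter(cg, pb);

  // One column-vector expression per example...
  std::vector<dynet::Expression> inputs;
  for (auto & ex : examples)
    inputs.emplace_back(dynet::input(cg, {inputDim}, ex));

  // ...merged into a single expression whose batch dimension is 2.
  // The affine layer broadcasts over the batch automatically.
  dynet::Expression batchedInput = dynet::concatenate_to_batch(inputs);
  dynet::Expression output = W * batchedInput + b;

  // One neglogsoftmax per batch element, summed to the scalar backward() needs.
  dynet::Expression loss = dynet::sum_batches(dynet::pickneglogsoftmax(output, golds));
  cg.forward(loss);
  cg.backward(loss);
  trainer.update();

  return 0;
}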