diff --git a/neural_network/src/MultiMLP.cpp b/neural_network/src/MultiMLP.cpp index 6458022021ae855141059f5a4a7d1faff8de87b9..9c01f6be2deffeec4df8c3554e9f47bb83dd2d6b 100644 --- a/neural_network/src/MultiMLP.cpp +++ b/neural_network/src/MultiMLP.cpp @@ -96,7 +96,12 @@ float MultiMLP::update(FeatureModel::FeatureDescription & fd, int gold) { int id = std::stoi(split(mlp.name, '_')[1]); mlp.setBatchSize(getBatchSize()); - loss += mlp.update(fd, id == gold ? 1 : 0); + if (gold >= 0) + loss += mlp.update(fd, id == gold ? 1 : 0); + else if (id == (-1-gold)) + loss += mlp.update(fd, 0); + else + continue; trainer->update(); } catch (BatchNotFull &) @@ -121,7 +126,12 @@ float MultiMLP::getLoss(FeatureModel::FeatureDescription & fd, int gold) { int id = std::stoi(split(mlp.name, '_')[1]); mlp.setBatchSize(getBatchSize()); - loss += mlp.getLoss(fd, id == gold ? 1 : 0); + if (gold >= 0) + loss += mlp.update(fd, id == gold ? 1 : 0); + else if (id == (-1-gold)) + loss += mlp.update(fd, 0); + else + continue; } catch (BatchNotFull &) { }