From 006cdfc9073d55656ff50a89561340b03477bb18 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 30 Jul 2019 15:32:15 +0200 Subject: [PATCH] trying to fix backtrack prediction --- neural_network/src/MLPBase.cpp | 2 +- trainer/src/Trainer.cpp | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/neural_network/src/MLPBase.cpp b/neural_network/src/MLPBase.cpp index 00ce27c..18585c0 100644 --- a/neural_network/src/MLPBase.cpp +++ b/neural_network/src/MLPBase.cpp @@ -162,7 +162,7 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, const std::vector<f dynet::Expression batchedLoss; std::vector<dynet::Expression> goldExpressions; for (auto & gold : goldsContinuous) - goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({1,(unsigned int)gold.size()}), gold)); + goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({(unsigned int)gold.size()}), gold)); dynet::Expression batchedGold = dynet::concatenate_to_batch(goldExpressions); batchedLoss = dynet::sum_batches(dynet::squared_distance(output, batchedGold)); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 61032a3..38c6a35 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -400,20 +400,36 @@ void Trainer::doStepTrain() if (TI.lastActionWasPredicted[normalStateName]) { + std::string updateInfos; + if (newCost >= lastCost) { - loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); +// loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON")); + int nbActions = tm.getCurrentClassifier()->getNbActions(); + int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()); + float value = 1.0 / (nbActions-1); + std::vector<float> goldOutput(nbActions, value); + goldOutput[backIndex] = 0.0; + + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput); + + updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, bad decision"; } else { - loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); +//loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top())); + int nbActions = tm.getCurrentClassifier()->getNbActions(); + int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()); + std::vector<float> goldOutput(nbActions, 0.0); + goldOutput[backIndex] = 1.0; + + loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput); - if (ProgramParameters::debug) - fprintf(stderr, "Updating neural network \'%s\', gold=\'%s\'\n", tm.getCurrentClassifier()->name.c_str(), trainConfig.getCurrentStateHistory().top().c_str()); + updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, good decision"; } if (ProgramParameters::debug) - fprintf(stderr, "Updating neural network \'%s\'\n", tm.getCurrentClassifier()->name.c_str()); + fprintf(stderr, "Updating neural network \'%s\' : %s\n", tm.getCurrentClassifier()->name.c_str(), updateInfos.c_str()); TI.addTrainLoss(tm.getCurrentClassifier()->name, loss); } -- GitLab