diff --git a/neural_network/src/MLPBase.cpp b/neural_network/src/MLPBase.cpp
index 00ce27cf4d3526f1f7c052999275b8dd1fa9e379..18585c05fff8726dc7098edf1ad25282c86ca254 100644
--- a/neural_network/src/MLPBase.cpp
+++ b/neural_network/src/MLPBase.cpp
@@ -162,7 +162,7 @@ float MLPBase::update(FeatureModel::FeatureDescription & fd, const std::vector<f
   dynet::Expression batchedLoss;
   std::vector<dynet::Expression> goldExpressions;
   for (auto & gold : goldsContinuous)
-    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({1,(unsigned int)gold.size()}), gold));
+    goldExpressions.emplace_back(dynet::input(cg, dynet::Dim({(unsigned int)gold.size()}), gold));
 
   dynet::Expression batchedGold = dynet::concatenate_to_batch(goldExpressions);
   batchedLoss = dynet::sum_batches(dynet::squared_distance(output, batchedGold));
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 61032a35bedc17a0e8584eeb6e2ac360af7975e6..38c6a35a41b03a4287aa3bc1f98f47b490aa8a52 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -400,20 +400,36 @@ void Trainer::doStepTrain()
 
   if (TI.lastActionWasPredicted[normalStateName])
   {
+    std::string updateInfos;
+
     if (newCost >= lastCost)
     {
-      loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
+//    loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex("EPSILON"));
+      int nbActions = tm.getCurrentClassifier()->getNbActions();
+      int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
+      float value = 1.0 / (nbActions-1);
+      std::vector<float> goldOutput(nbActions, value);
+      goldOutput[backIndex] = 0.0;
+
+      loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
+
+      updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, bad decision";
     }
     else
     {
-      loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
+//loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top()));
+      int nbActions = tm.getCurrentClassifier()->getNbActions();
+      int backIndex = tm.getCurrentClassifier()->getActionIndex(trainConfig.getCurrentStateHistory().top());
+      std::vector<float> goldOutput(nbActions, 0.0);
+      goldOutput[backIndex] = 1.0;
+
+      loss = tm.getCurrentClassifier()->trainOnExample(pendingFD[tm.getCurrentClassifier()->name], goldOutput);
 
-      if (ProgramParameters::debug)
-        fprintf(stderr, "Updating neural network \'%s\', gold=\'%s\'\n", tm.getCurrentClassifier()->name.c_str(), trainConfig.getCurrentStateHistory().top().c_str());
+      updateInfos = "predicted : <"+trainConfig.getCurrentStateHistory().top()+">, good decision";
     }
 
     if (ProgramParameters::debug)
-      fprintf(stderr, "Updating neural network \'%s\'\n", tm.getCurrentClassifier()->name.c_str());
+      fprintf(stderr, "Updating neural network \'%s\' : %s\n", tm.getCurrentClassifier()->name.c_str(), updateInfos.c_str());
 
     TI.addTrainLoss(tm.getCurrentClassifier()->name, loss);
   }
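
Note (not part of the patch): a minimal standalone sketch of the gold-output construction that the Trainer.cpp hunk introduces, assuming the same semantics as the diff. buildGoldOutput is a hypothetical helper used only for illustration; nbActions and backIndex mirror the variables in the patch. When the last predicted action did not improve the cost (newCost >= lastCost), the target spreads 1/(nbActions-1) over every action except the one taken; otherwise it is a one-hot on that action. The MLPBase.cpp hunk then feeds each such vector as a one-dimensional dynet input (Dim({gold.size()})) so the squared-distance loss against the network output is computed per batch element.

#include <vector>

// Hypothetical helper, for illustration only: builds the gold output vector
// that the patched Trainer::doStepTrain() passes to trainOnExample().
std::vector<float> buildGoldOutput(int nbActions, int backIndex, bool badDecision)
{
  if (badDecision)
  {
    // Cost did not improve: zero target for the action that was taken,
    // uniform mass 1/(nbActions-1) over every other action.
    std::vector<float> goldOutput(nbActions, 1.0f / (nbActions - 1));
    goldOutput[backIndex] = 0.0f;
    return goldOutput;
  }

  // Cost improved: one-hot target on the action that was taken.
  std::vector<float> goldOutput(nbActions, 0.0f);
  goldOutput[backIndex] = 1.0f;
  return goldOutput;
}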