From 048c959c91b3ca37086a50ff3361b570a747f68f Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 22 Jun 2020 22:58:31 +0200
Subject: [PATCH] Added custom hinge loss and added explorationThreshold program argument

---
 torch_modules/include/CustomHingeLoss.hpp | 13 ++++++++
 torch_modules/src/CustomHingeLoss.cpp     | 33 ++++++++++++++++++++
 trainer/include/Trainer.hpp               |  7 +++--
 trainer/src/MacaonTrain.cpp               |  9 ++++--
 trainer/src/Trainer.cpp                   | 36 ++++++++++++++++-------
 5 files changed, 81 insertions(+), 17 deletions(-)
 create mode 100644 torch_modules/include/CustomHingeLoss.hpp
 create mode 100644 torch_modules/src/CustomHingeLoss.cpp

diff --git a/torch_modules/include/CustomHingeLoss.hpp b/torch_modules/include/CustomHingeLoss.hpp
new file mode 100644
index 0000000..e30e2b8
--- /dev/null
+++ b/torch_modules/include/CustomHingeLoss.hpp
@@ -0,0 +1,13 @@
+#ifndef CUSTOMHINGELOSS__H
+#define CUSTOMHINGELOSS__H
+
+#include <torch/torch.h>
+
+class CustomHingeLoss
+{
+  public :
+
+  torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold);
+};
+
+#endif
diff --git a/torch_modules/src/CustomHingeLoss.cpp b/torch_modules/src/CustomHingeLoss.cpp
new file mode 100644
index 0000000..bc3fc85
--- /dev/null
+++ b/torch_modules/src/CustomHingeLoss.cpp
@@ -0,0 +1,33 @@
+#include "CustomHingeLoss.hpp"
+
+// Multi-label hinge loss, averaged over the first dimension of prediction.
+//
+// For every index i along dimension 0, accumulates
+//   relu(1 - max(gold[i] * prediction[i]) + max((1 - gold[i]) * prediction[i]))
+// and divides the sum by prediction.size(0).
+//
+// prediction : scores, one entry per example along dimension 0.
+// gold       : indexed like prediction; presumably a 0/1 mask of the gold
+//              classes (TODO confirm against the caller).
+// Returns a tensor with a single element holding the averaged loss.
+torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
+{
+  // Build the accumulator from prediction's options (dtype and device) so the
+  // in-place additions below do not mix devices when prediction is not on the
+  // default CPU device.
+  torch::Tensor loss = torch::zeros({1}, prediction.options());
+
+  // size(0) is a signed 64-bit value; keep the loop index in the same type.
+  const int64_t nbExamples = prediction.size(0);
+
+  for (int64_t i = 0; i < nbExamples; i++)
+  {
+    loss += torch::relu(1 - torch::max(gold[i]*prediction[i])
+        + torch::max((1-gold[i])*prediction[i]));
+  }
+
+  loss /= nbExamples;
+
+  return loss;
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index ad14ef6..25d48f4 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -4,12 +4,13 @@
 #include "ReadingMachine.hpp"
 #include "ConfigDataset.hpp"
 #include "SubConfig.hpp"
+#include "CustomHingeLoss.hpp"
 
 class 
LossFunction { private : - std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct; + std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct; public : @@ -69,13 +70,13 @@ class Trainer private : - void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); + void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); public : Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName); - void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); + void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); float epoch(bool printAdvancement); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 9db5c5a..b999e2c 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -33,12 +33,14 @@ po::options_description MacaonTrain::getOptionsDescription() "Number of training epochs") ("batchSize", po::value<int>()->default_value(64), "Number of examples per batch") + ("explorationThreshold", po::value<float>()->default_value(0.1), + "Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset") ("machine", po::value<std::string>()->default_value(""), "Reading machine file content") ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), "Description of what should happen during training") ("loss", po::value<std::string>()->default_value("CrossEntropy"), - "Loss function to 
use during training : CrossEntropy | bce | mse") + "Loss function to use during training : CrossEntropy | bce | mse | hinge") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -128,6 +130,7 @@ int MacaonTrain::main() auto machineContent = variables["machine"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>(); + auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto trainStrategy = parseTrainStrategy(trainStrategyStr); @@ -211,11 +214,11 @@ int MacaonTrain::main() if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) { machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); - trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); if (!computeDevScore) { machine.setDictsState(Dict::State::Closed); - trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); + trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold); } } if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 90b65b7..74d6d13 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -9,6 +9,8 @@ 
LossFunction::LossFunction(std::string name) fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); else if (util::lower(name) == "mse") fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "hinge") + fct = CustomHingeLoss(); else util::myThrow(fmt::format("unknown loss function name '{}'", name)); } @@ -23,6 +25,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); if (index == 2) return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); + if (index == 3) + return std::get<3>(fct)(torch::softmax(prediction, 1), gold); util::myThrow("loss is not defined"); return torch::Tensor(); @@ -38,7 +42,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: gold[0] = goldIndexes.at(0); return gold; } - if (index == 1 or index == 2) + if (index == 1 or index == 2 or index == 3) { auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); for (auto goldIndex : goldIndexes) @@ -66,18 +70,18 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir) devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); } -void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) +void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) { SubConfig config(goldConfig, goldConfig.getNbLines()); machine.trainMode(false); - extractExamples(config, debug, dir, epoch, dynamicOracle); + extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold); machine.saveDicts(); } -void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) +void 
Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold) { torch::AutoGradMode useGrad(false); @@ -129,22 +133,32 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); + auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze(); - int chosenTransition = -1; float bestScore = std::numeric_limits<float>::min(); + std::vector<int> candidates; for (unsigned int i = 0; i < prediction.size(0); i++) { float score = prediction[i].item<float>(); - if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) - { - chosenTransition = i; + if (score > bestScore and appliableTransitions[i]) bestScore = score; - } } - transition = machine.getTransitionSet().getTransition(chosenTransition); + for (unsigned int i = 0; i < prediction.size(0); i++) + { + float score = prediction[i].item<float>(); + if (appliableTransitions[i] and bestScore - score <= explorationThreshold) + candidates.emplace_back(i); + } + + if (candidates.size() != 1) + { + fmt::print(stderr, "nbCand = {}\n", candidates.size()); + std::cerr << prediction << std::endl; + } + + transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]); } else { -- GitLab