Skip to content
Snippets Groups Projects
Commit 048c959c authored by Franck Dary's avatar Franck Dary
Browse files

Added custom hinge loss and added explorationThreshold program argument

parent 8ec956e6
No related branches found
No related tags found
No related merge requests found
#ifndef CUSTOMHINGELOSS__H
#define CUSTOMHINGELOSS__H
#include <torch/torch.h>
class CustomHingeLoss
{
public :
torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold);
};
#endif
#include "CustomHingeLoss.hpp"
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
torch::Tensor loss = torch::zeros(1);
for (unsigned int i = 0; i < prediction.size(0); i++)
{
loss += torch::relu(1 - torch::max(gold[i]*prediction[i])
+ torch::max((1-gold[i])*prediction[i]));
}
loss /= prediction.size(0);
return loss;
}
......@@ -4,12 +4,13 @@
#include "ReadingMachine.hpp"
#include "ConfigDataset.hpp"
#include "SubConfig.hpp"
#include "CustomHingeLoss.hpp"
class LossFunction
{
private :
std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct;
std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct;
public :
......@@ -69,13 +70,13 @@ class Trainer
private :
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
public :
Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir);
float epoch(bool printAdvancement);
......
......@@ -33,12 +33,14 @@ po::options_description MacaonTrain::getOptionsDescription()
"Number of training epochs")
("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch")
("explorationThreshold", po::value<float>()->default_value(0.1),
"Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
("machine", po::value<std::string>()->default_value(""),
"Reading machine file content")
("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
"Description of what should happen during training")
("loss", po::value<std::string>()->default_value("CrossEntropy"),
"Loss function to use during training : CrossEntropy | bce | mse")
"Loss function to use during training : CrossEntropy | bce | mse | hinge")
("help,h", "Produce this help message");
desc.add(req).add(opt);
......@@ -128,6 +130,7 @@ int MacaonTrain::main()
auto machineContent = variables["machine"].as<std::string>();
auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto lossFunction = variables["loss"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr);
......@@ -211,11 +214,11 @@ int MacaonTrain::main()
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
{
machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
if (!computeDevScore)
{
machine.setDictsState(Dict::State::Closed);
trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic));
trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
}
}
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
......
......@@ -9,6 +9,8 @@ LossFunction::LossFunction(std::string name)
fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
else if (util::lower(name) == "mse")
fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
else if (util::lower(name) == "hinge")
fct = CustomHingeLoss();
else
util::myThrow(fmt::format("unknown loss function name '{}'", name));
}
......@@ -23,6 +25,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
if (index == 2)
return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
if (index == 3)
return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
util::myThrow("loss is not defined");
return torch::Tensor();
......@@ -38,7 +42,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
gold[0] = goldIndexes.at(0);
return gold;
}
if (index == 1 or index == 2)
if (index == 1 or index == 2 or index == 3)
{
auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
for (auto goldIndex : goldIndexes)
......@@ -66,18 +70,18 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir)
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
extractExamples(config, debug, dir, epoch, dynamicOracle);
extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold);
machine.saveDicts();
}
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle)
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
torch::AutoGradMode useGrad(false);
......@@ -129,22 +133,32 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
std::vector<int> candidates;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config))
{
chosenTransition = i;
if (score > bestScore and appliableTransitions[i])
bestScore = score;
}
}
transition = machine.getTransitionSet().getTransition(chosenTransition);
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
candidates.emplace_back(i);
}
if (candidates.size() != 1)
{
fmt::print(stderr, "nbCand = {}\n", candidates.size());
std::cerr << prediction << std::endl;
}
transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
}
else
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment