Skip to content
Snippets Groups Projects
Commit 048c959c authored by Franck Dary's avatar Franck Dary
Browse files

Added custom hinge loss and added explorationThreshold program argument

parent 8ec956e6
No related branches found
No related tags found
No related merge requests found
#ifndef CUSTOMHINGELOSS__H
#define CUSTOMHINGELOSS__H
#include <torch/torch.h>
class CustomHingeLoss
{
public :
torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold);
};
#endif
#include "CustomHingeLoss.hpp"
torch::Tensor CustomHingeLoss::operator()(torch::Tensor prediction, torch::Tensor gold)
{
torch::Tensor loss = torch::zeros(1);
for (unsigned int i = 0; i < prediction.size(0); i++)
{
loss += torch::relu(1 - torch::max(gold[i]*prediction[i])
+ torch::max((1-gold[i])*prediction[i]));
}
loss /= prediction.size(0);
return loss;
}
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
#include "ReadingMachine.hpp" #include "ReadingMachine.hpp"
#include "ConfigDataset.hpp" #include "ConfigDataset.hpp"
#include "SubConfig.hpp" #include "SubConfig.hpp"
#include "CustomHingeLoss.hpp"
class LossFunction class LossFunction
{ {
private : private :
std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct; std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct;
public : public :
...@@ -69,13 +70,13 @@ class Trainer ...@@ -69,13 +70,13 @@ class Trainer
private : private :
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
public : public :
Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName); Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold);
void makeDataLoader(std::filesystem::path dir); void makeDataLoader(std::filesystem::path dir);
void makeDevDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir);
float epoch(bool printAdvancement); float epoch(bool printAdvancement);
......
...@@ -33,12 +33,14 @@ po::options_description MacaonTrain::getOptionsDescription() ...@@ -33,12 +33,14 @@ po::options_description MacaonTrain::getOptionsDescription()
"Number of training epochs") "Number of training epochs")
("batchSize", po::value<int>()->default_value(64), ("batchSize", po::value<int>()->default_value(64),
"Number of examples per batch") "Number of examples per batch")
("explorationThreshold", po::value<float>()->default_value(0.1),
"Maximum probability difference with the best scoring transition, for a transition to be explored during dynamic extraction of dataset")
("machine", po::value<std::string>()->default_value(""), ("machine", po::value<std::string>()->default_value(""),
"Reading machine file content") "Reading machine file content")
("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"),
"Description of what should happen during training") "Description of what should happen during training")
("loss", po::value<std::string>()->default_value("CrossEntropy"), ("loss", po::value<std::string>()->default_value("CrossEntropy"),
"Loss function to use during training : CrossEntropy | bce | mse") "Loss function to use during training : CrossEntropy | bce | mse | hinge")
("help,h", "Produce this help message"); ("help,h", "Produce this help message");
desc.add(req).add(opt); desc.add(req).add(opt);
...@@ -128,6 +130,7 @@ int MacaonTrain::main() ...@@ -128,6 +130,7 @@ int MacaonTrain::main()
auto machineContent = variables["machine"].as<std::string>(); auto machineContent = variables["machine"].as<std::string>();
auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>();
auto lossFunction = variables["loss"].as<std::string>(); auto lossFunction = variables["loss"].as<std::string>();
auto explorationThreshold = variables["explorationThreshold"].as<float>();
auto trainStrategy = parseTrainStrategy(trainStrategyStr); auto trainStrategy = parseTrainStrategy(trainStrategyStr);
...@@ -211,11 +214,11 @@ int MacaonTrain::main() ...@@ -211,11 +214,11 @@ int MacaonTrain::main()
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)) if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractGold) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic))
{ {
machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open); machine.setDictsState(trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic) ? Dict::State::Closed : Dict::State::Open);
trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); trainer.createDataset(goldConfig, debug, modelPath/"examples/train", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
if (!computeDevScore) if (!computeDevScore)
{ {
machine.setDictsState(Dict::State::Closed); machine.setDictsState(Dict::State::Closed);
trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic)); trainer.createDataset(devGoldConfig, debug, modelPath/"examples/dev", currentEpoch, trainStrategy[currentEpoch].count(Trainer::TrainAction::ExtractDynamic), explorationThreshold);
} }
} }
if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer)) if (trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetParameters) or trainStrategy[currentEpoch].count(Trainer::TrainAction::ResetOptimizer))
......
...@@ -9,6 +9,8 @@ LossFunction::LossFunction(std::string name) ...@@ -9,6 +9,8 @@ LossFunction::LossFunction(std::string name)
fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
else if (util::lower(name) == "mse") else if (util::lower(name) == "mse")
fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
else if (util::lower(name) == "hinge")
fct = CustomHingeLoss();
else else
util::myThrow(fmt::format("unknown loss function name '{}'", name)); util::myThrow(fmt::format("unknown loss function name '{}'", name));
} }
...@@ -23,6 +25,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g ...@@ -23,6 +25,8 @@ torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor g
return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
if (index == 2) if (index == 2)
return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
if (index == 3)
return std::get<3>(fct)(torch::softmax(prediction, 1), gold);
util::myThrow("loss is not defined"); util::myThrow("loss is not defined");
return torch::Tensor(); return torch::Tensor();
...@@ -38,7 +42,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std:: ...@@ -38,7 +42,7 @@ torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::
gold[0] = goldIndexes.at(0); gold[0] = goldIndexes.at(0);
return gold; return gold;
} }
if (index == 1 or index == 2) if (index == 1 or index == 2 or index == 3)
{ {
auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
for (auto goldIndex : goldIndexes) for (auto goldIndex : goldIndexes)
...@@ -66,18 +70,18 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir) ...@@ -66,18 +70,18 @@ void Trainer::makeDevDataLoader(std::filesystem::path dir)
devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
} }
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{ {
SubConfig config(goldConfig, goldConfig.getNbLines()); SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false); machine.trainMode(false);
extractExamples(config, debug, dir, epoch, dynamicOracle); extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold);
machine.saveDicts(); machine.saveDicts();
} }
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle) void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
...@@ -129,22 +133,32 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -129,22 +133,32 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{ {
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min(); float bestScore = std::numeric_limits<float>::min();
std::vector<int> candidates;
for (unsigned int i = 0; i < prediction.size(0); i++) for (unsigned int i = 0; i < prediction.size(0); i++)
{ {
float score = prediction[i].item<float>(); float score = prediction[i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and machine.getTransitionSet().getTransition(i)->appliable(config)) if (score > bestScore and appliableTransitions[i])
{
chosenTransition = i;
bestScore = score; bestScore = score;
} }
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
candidates.emplace_back(i);
}
if (candidates.size() != 1)
{
fmt::print(stderr, "nbCand = {}\n", candidates.size());
std::cerr << prediction << std::endl;
} }
transition = machine.getTransitionSet().getTransition(chosenTransition); transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
} }
else else
{ {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment