From 8ec956e65b92712146b113e5fb63e7aa01b23d75 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Sun, 21 Jun 2020 23:56:39 +0200 Subject: [PATCH] Added bce and mse losses --- reading_machine/include/TransitionSet.hpp | 2 +- reading_machine/src/TransitionSet.cpp | 24 ++-- torch_modules/include/ConfigDataset.hpp | 3 +- torch_modules/src/ConfigDataset.cpp | 10 +- trainer/include/Trainer.hpp | 19 ++- trainer/src/MacaonTrain.cpp | 5 +- trainer/src/Trainer.cpp | 146 +++++++++++----------- 7 files changed, 116 insertions(+), 93 deletions(-) diff --git a/reading_machine/include/TransitionSet.hpp b/reading_machine/include/TransitionSet.hpp index 1d1bc75..936fa88 100644 --- a/reading_machine/include/TransitionSet.hpp +++ b/reading_machine/include/TransitionSet.hpp @@ -21,7 +21,7 @@ class TransitionSet TransitionSet(const std::vector<std::string> & filenames); TransitionSet(const std::string & filename); std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c, bool dynamic = false); - Transition * getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false); + std::vector<Transition *> getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic = false); std::vector<Transition *> getNAppliableTransitions(const Config & c, int n); std::vector<int> getAppliableTransitions(const Config & c); std::size_t getTransitionIndex(const Transition * transition) const; diff --git a/reading_machine/src/TransitionSet.cpp b/reading_machine/src/TransitionSet.cpp index 8c5c1a8..8f146e7 100644 --- a/reading_machine/src/TransitionSet.cpp +++ b/reading_machine/src/TransitionSet.cpp @@ -80,28 +80,31 @@ std::vector<int> TransitionSet::getAppliableTransitions(const Config & c) return result; } -Transition * TransitionSet::getBestAppliableTransition(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic) +std::vector<Transition *> TransitionSet::getBestAppliableTransitions(const Config & c, const std::vector<int> & appliableTransitions, bool dynamic) { - Transition * result = nullptr; int bestCost = std::numeric_limits<int>::max(); + std::vector<Transition *> result; + std::vector<int> costs(transitions.size()); for (unsigned int i = 0; i < transitions.size(); i++) { if (!appliableTransitions[i]) + { + costs[i] = std::numeric_limits<int>::max(); continue; + } int cost = dynamic ? transitions[i].getCostDynamic(c) : transitions[i].getCostStatic(c); - if (cost == 0) - return &transitions[i]; - + costs[i] = cost; if (cost < bestCost) - { - result = &transitions[i]; bestCost = cost; - } } + for (unsigned int i = 0; i < transitions.size(); i++) + if (costs[i] == bestCost) + result.emplace_back(&transitions[i]); + return result; } @@ -115,7 +118,10 @@ std::size_t TransitionSet::getTransitionIndex(const Transition * transition) con if (!transition) util::myThrow("transition is null"); - return transition - &transitions[0]; + int index = transition - &transitions[0]; + if (index < 0 or index >= (int)transitions.size()) + util::myThrow(fmt::format("transition index '{}' out of bounds [0;{}[", index, transitions.size())); + return index; } Transition * TransitionSet::getTransition(std::size_t index) diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp index 7ea2335..4090a80 100644 --- a/torch_modules/include/ConfigDataset.hpp +++ b/torch_modules/include/ConfigDataset.hpp @@ -17,8 +17,9 @@ class ConfigDataset : public torch::data::datasets::StatefulDataset<ConfigDatase int nextIndexToGive{0}; std::size_t size_{0}; std::size_t nbGiven{0}; + int nbClasses; - Holder(std::string state); + Holder(std::string state, int nbClasses); void addFile(std::string filename, int filesize); void reset(); std::size_t size() const; diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp index d15064e..2b42eef 100644 --- a/torch_modules/src/ConfigDataset.cpp +++ b/torch_modules/src/ConfigDataset.cpp @@ -10,13 +10,15 @@ ConfigDataset::ConfigDataset(std::filesystem::path dir) if (stem == "extracted") continue; auto underSplit = util::split(stem, '_'); - auto state = util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)); + auto stateAndNbClasses = util::split(util::join("_", std::vector<std::string>(underSplit.begin(), underSplit.end()-1)), '-'); + auto state = stateAndNbClasses[0]; + auto nbClasses = std::stoi(stateAndNbClasses[1]); auto splited = util::split(underSplit.back(), '-'); int fileSize = 1 + std::stoi(splited[1]) - std::stoi(splited[0]); size_ += fileSize; if (!holders.count(state)) { - holders.emplace(state, Holder(state)); + holders.emplace(state, Holder(state, nbClasses)); order.emplace_back(state); } holders.at(state).addFile(entry.path().string(), fileSize); @@ -111,10 +113,10 @@ c10::optional<std::tuple<torch::Tensor,torch::Tensor,std::string>> ConfigDataset nbGiven += nbElementsToGive; auto batch = loadedTensor.narrow(0, nextIndexToGive, nbElementsToGive); nextIndexToGive += nbElementsToGive; - return std::make_tuple(batch.narrow(1, 0, batch.size(1)-1), batch.narrow(1, batch.size(1)-1, 1), state); + return std::make_tuple(batch.narrow(1, 0, batch.size(1)-nbClasses), batch.narrow(1, batch.size(1)-nbClasses, nbClasses), state); } -ConfigDataset::Holder::Holder(std::string state) : state(state) +ConfigDataset::Holder::Holder(std::string state, int nbClasses) : state(state), nbClasses(nbClasses) { } diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index d566747..ad14ef6 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -5,6 +5,19 @@ #include "ConfigDataset.hpp" #include "SubConfig.hpp" +class LossFunction +{ + private : + + std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss> fct; + + public : + + LossFunction(std::string name); + torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold); + torch::Tensor getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const; +}; + class Trainer { public : @@ -35,7 +48,7 @@ class Trainer void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle); void addContext(std::vector<std::vector<long>> & context); - void addClass(int goldIndex); + void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes); }; private : @@ -52,16 +65,16 @@ class Trainer DataLoader devDataLoader{nullptr}; std::size_t epochNumber{0}; int batchSize; + LossFunction lossFct; private : void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples); - void fillDicts(SubConfig & config, bool debug); public : - Trainer(ReadingMachine & machine, int batchSize); + Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 8e43cd3..9db5c5a 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -37,6 +37,8 @@ po::options_description MacaonTrain::getOptionsDescription() "Reading machine file content") ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), "Description of what should happen during training") + ("loss", po::value<std::string>()->default_value("CrossEntropy"), + "Loss function to use during training : CrossEntropy | bce | mse") ("help,h", "Produce this help message"); desc.add(req).add(opt); @@ -125,6 +127,7 @@ int MacaonTrain::main() bool computeDevScore = variables.count("devScore") == 0 ? false : true; auto machineContent = variables["machine"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); + auto lossFunction = variables["loss"].as<std::string>(); auto trainStrategy = parseTrainStrategy(trainStrategyStr); @@ -149,7 +152,7 @@ int MacaonTrain::main() BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); - Trainer trainer(machine, batchSize); + Trainer trainer(machine, batchSize, lossFunction); Decoder decoder(machine); if (!util::findFilesByExtension(machinePath.parent_path(), ".dict").empty()) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index af1ef2e..90b65b7 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,7 +1,56 @@ #include "Trainer.hpp" #include "SubConfig.hpp" -Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) +LossFunction::LossFunction(std::string name) +{ + if (util::lower(name) == "crossentropy") + fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "bce") + fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "mse") + fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); + else + util::myThrow(fmt::format("unknown loss function name '{}'", name)); +} + +torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold) +{ + auto index = fct.index(); + + if (index == 0) + return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0))); + if (index == 1) + return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); + if (index == 2) + return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); + + util::myThrow("loss is not defined"); + return torch::Tensor(); +} + +torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const +{ + auto index = fct.index(); + + if (index == 0) + { + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); + gold[0] = goldIndexes.at(0); + return gold; + } + if (index == 1 or index == 2) + { + auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); + for (auto goldIndex : goldIndexes) + gold[goldIndex] = 1; + return gold; + } + + util::myThrow("loss is not defined"); + return torch::Tensor(); +} + +Trainer::Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName) : machine(machine), batchSize(batchSize), lossFct(lossFunctionName) { } @@ -72,9 +121,10 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p } Transition * transition = nullptr; - Transition * goldTransition = nullptr; - goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions, true or dynamicOracle); + auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle); + Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()]; + int nbClasses = machine.getTransitionSet().size(); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { @@ -107,14 +157,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p util::myThrow("No transition appliable !"); } - int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition); - totalNbExamples += context.size(); if (totalNbExamples >= (int)safetyNbExamplesMax) util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); + std::vector<int> goldIndexes; + for (auto & t : goldTransitions) + goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t)); + examplesPerState[config.getState()].addContext(context); - examplesPerState[config.getState()].addClass(goldIndex); + examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes); examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); transition->apply(config); @@ -156,8 +208,6 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance torch::AutoGradMode useGrad(train); machine.trainMode(train); - auto lossFct = torch::nn::CrossEntropyLoss(); - auto pastTime = std::chrono::high_resolution_clock::now(); for (auto & batch : *loader) @@ -175,26 +225,27 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (prediction.dim() == 1) prediction = prediction.unsqueeze(0); - labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); - auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels); + float lossAsFloat = 0.0; try { - totalLoss += loss.item<float>(); - lossSoFar += loss.item<float>(); + lossAsFloat = loss.item<float>(); } catch(std::exception & e) {util::myThrow(e.what());} + totalLoss += lossAsFloat; + lossSoFar += lossAsFloat; + if (train) { loss.backward(); machine.getClassifier()->getOptimizer().step(); } - totalNbExamplesProcessed += torch::numel(labels); + totalNbExamplesProcessed += labels.size(0); if (printAdvancement) { - nbExamplesProcessed += torch::numel(labels); + nbExamplesProcessed += labels.size(0); if (nbExamplesProcessed >= printInterval) { @@ -234,8 +285,10 @@ void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem: if (contexts.empty()) return; + int nbClasses = classes[0].size(0); + auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1); - auto filename = fmt::format("{}_{}-{}.{}.{}.tensor", state, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); + auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle); torch::save(tensorToSave, dir/filename); lastSavedIndex = currentExampleIndex; contexts.clear(); @@ -250,67 +303,12 @@ void Trainer::Examples::addContext(std::vector<std::vector<long>> & context) currentExampleIndex += context.size(); } -void Trainer::Examples::addClass(int goldIndex) +void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes) { - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); - gold[0] = goldIndex; - - while (classes.size() < contexts.size()) - classes.emplace_back(gold); -} - -void Trainer::fillDicts(SubConfig & config, bool debug) -{ - torch::AutoGradMode useGrad(false); - - config.addPredicted(machine.getPredicted()); - config.setStrategy(machine.getStrategyDefinition()); - config.setState(config.getStrategy().getInitialState()); - machine.getClassifier()->setState(config.getState()); + auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes); - while (true) - { - if (debug) - config.printForDebug(stderr); - - if (machine.hasSplitWordTransitionSet()) - config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions)); - auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config); - config.setAppliableTransitions(appliableTransitions); - - try - { - machine.getClassifier()->getNN()->extractContext(config); - } catch(std::exception & e) - { - util::myThrow(fmt::format("Failed to extract context : {}", e.what())); - } - - Transition * goldTransition = nullptr; - goldTransition = machine.getTransitionSet().getBestAppliableTransition(config, appliableTransitions); - - if (!goldTransition) - { - config.printForDebug(stderr); - util::myThrow("No transition appliable !"); - } - - goldTransition->apply(config); - config.addToHistory(goldTransition->getName()); - - auto movement = config.getStrategy().getMovement(config, goldTransition->getName()); - if (debug) - fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second); - if (movement == Strategy::endMovement) - break; - - config.setState(movement.first); - machine.getClassifier()->setState(movement.first); - config.moveWordIndexRelaxed(movement.second); - - if (config.needsUpdate()) - config.update(); - } + while (classes.size() < contexts.size()) + classes.emplace_back(gold); } Trainer::TrainAction Trainer::str2TrainAction(const std::string & s) -- GitLab