#include "Trainer.hpp"
#include "SubConfig.hpp"

/// @brief Build a loss functor from its (case-insensitive) name.
///
/// Supported names: "crossentropy", "bce", "mse", "hinge".
/// All torch losses use mean reduction. Throws on any other name.
LossFunction::LossFunction(std::string name)
{
  if (util::lower(name) == "crossentropy")
    fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "bce")
    fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "mse")
    fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean));
  else if (util::lower(name) == "hinge")
    fct = CustomHingeLoss();
  else
    util::myThrow(fmt::format("unknown loss function name '{}'", name));
}

/// @brief Compute the loss of a prediction batch against the gold tensor.
///
/// Dispatch depends on the active variant alternative:
///  - index 0 (CrossEntropy) : raw logits, gold reshaped to a 1-D class-index vector;
///  - index 1 (BCE) / index 2 (MSE) : predictions are softmax-ed, gold cast to float;
///  - index 3 (custom hinge) : predictions are softmax-ed, gold passed as-is.
torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold)
{
  auto index = fct.index();

  if (index == 0)
    return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0)));
  if (index == 1)
    return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
  if (index == 2)
    return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat));
  if (index == 3)
    return std::get<3>(fct)(torch::softmax(prediction, 1), gold);

  util::myThrow("loss is not defined");
  return torch::Tensor();
}

/// @brief Encode gold class indexes into the tensor layout the active loss expects.
///
/// CrossEntropy (index 0) gets a single class index (the first gold index);
/// the other losses get a multi-hot vector of length nbClasses with a 1 at
/// every gold index.
torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const
{
  auto index = fct.index();

  if (index == 0)
  {
    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
    gold[0] = goldIndexes.at(0);
    return gold;
  }
  if (index == 1 or index == 2 or index == 3)
  {
    auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong));
    for (auto goldIndex : goldIndexes)
      gold[goldIndex] = 1;
    return gold;
  }

  util::myThrow("loss is not defined");
  return torch::Tensor();
}

/// @brief Construct a trainer bound to a reading machine, with a fixed batch
/// size and a loss function selected by name (see LossFunction).
Trainer::Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName) : machine(machine), batchSize(batchSize), lossFct(lossFunctionName)
{
}

/// @brief (Re)create the training dataset from the examples stored in dir and
/// wrap it in a single-threaded data loader.
void Trainer::makeDataLoader(std::filesystem::path dir)
{
  trainDataset.reset(new Dataset(dir));
  dataLoader = torch::data::make_data_loader(*trainDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

/// @brief Same as makeDataLoader, but for the dev dataset/loader.
void Trainer::makeDevDataLoader(std::filesystem::path dir)
{
  devDataset.reset(new Dataset(dir));
  devDataLoader = torch::data::make_data_loader(*devDataset, torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
}

/// @brief Run the machine over the gold configuration (in eval mode) to extract
/// training examples into dir, then persist the machine's dictionaries.
void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
  SubConfig config(goldConfig, goldConfig.getNbLines());

  machine.trainMode(false);

  extractExamples(config, debug, dir, epoch, dynamicOracle, explorationThreshold);

  machine.saveDicts();
}

/// @brief Walk the transition system over config and dump (context, gold class)
/// training examples to per-state tensor files under dir.
///
/// When dynamicOracle is set (and the state is neither "tokenizer" nor
/// "segmenter"), the applied transition is sampled from the network's own
/// predictions among appliable transitions whose score is within
/// explorationThreshold of the best one; otherwise the gold transition is
/// applied. An "extracted.<epoch>.<dynamicOracle>" marker file makes the
/// extraction idempotent per epoch.
void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold)
{
  torch::AutoGradMode useGrad(false); // extraction never needs gradients

  int maxNbExamplesPerFile = 50000;
  std::map<std::string, Examples> examplesPerState;

  std::filesystem::create_directories(dir);

  config.addPredicted(machine.getPredicted());
  config.setStrategy(machine.getStrategyDefinition());
  config.setState(config.getStrategy().getInitialState());
  machine.getClassifier()->setState(config.getState());

  // Marker file: if it exists, this (epoch, oracle-mode) pair was already extracted.
  auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}.{}", epoch, dynamicOracle);
  if (std::filesystem::exists(currentEpochAllExtractedFile))
    return;

  fmt::print(stderr, "[{}] Starting to extract examples{}\n", util::getTime(), dynamicOracle ? ", dynamic oracle" : "");

  int totalNbExamples = 0;

  while (true)
  {
    if (debug)
      config.printForDebug(stderr);

    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(config);
    config.setAppliableTransitions(appliableTransitions);

    std::vector<std::vector<long>> context;

    try
    {
      context = machine.getClassifier()->getNN()->extractContext(config);
    } catch (std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    Transition * transition = nullptr;

    // NOTE(review): 'true or dynamicOracle' always evaluates to true, so the
    // flag is ignored here — looks like a debugging leftover; confirm intent.
    auto goldTransitions = machine.getTransitionSet().getBestAppliableTransitions(config, appliableTransitions, true or dynamicOracle);

    // Break ties between equally-good gold transitions at random.
    Transition * goldTransition = goldTransitions[std::rand()%goldTransitions.size()];
    int nbClasses = machine.getTransitionSet().size();

    // NOTE(review): choiceWithProbability(1.0) is presumably always true —
    // likely a leftover from tuning the exploration probability; confirm.
    if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
    {
      auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
      auto prediction = torch::softmax(machine.getClassifier()->getNN()(neuralInput), -1).squeeze();

      // FIX: was std::numeric_limits<float>::min(), which is the smallest
      // POSITIVE float, not the most negative value. If every appliable
      // transition's softmax score underflowed to 0.0, bestScore never
      // updated, candidates stayed empty, and the modulo below divided by
      // zero. lowest() guarantees at least one appliable candidate.
      float bestScore = std::numeric_limits<float>::lowest();
      std::vector<int> candidates;

      // First pass: best score among appliable transitions.
      for (unsigned int i = 0; i < prediction.size(0); i++)
      {
        float score = prediction[i].item<float>();
        if (score > bestScore and appliableTransitions[i])
          bestScore = score;
      }

      // Second pass: keep every appliable transition close enough to the best.
      for (unsigned int i = 0; i < prediction.size(0); i++)
      {
        float score = prediction[i].item<float>();
        if (appliableTransitions[i] and bestScore - score <= explorationThreshold)
          candidates.emplace_back(i);
      }

      transition = machine.getTransitionSet().getTransition(candidates[std::rand()%candidates.size()]);
    }
    else
    {
      transition = goldTransition;
    }

    if (!transition or !goldTransition)
    {
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    totalNbExamples += context.size();
    if (totalNbExamples >= (int)safetyNbExamplesMax)
      util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax)));

    // All gold transitions count as positive classes for this configuration.
    std::vector<int> goldIndexes;
    for (auto & t : goldTransitions)
      goldIndexes.emplace_back(machine.getTransitionSet().getTransitionIndex(t));

    examplesPerState[config.getState()].addContext(context);
    examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes);
    examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = config.getStrategy().getMovement(config, transition->getName());
    if (debug)
      fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }

  // Flush whatever is left in every per-state buffer (threshold 0 forces a save).
  for (auto & it : examplesPerState)
    it.second.saveIfNeeded(it.first, dir, 0, epoch, dynamicOracle);

  // Create the empty marker file so the next call returns early.
  std::FILE * f = std::fopen(currentEpochAllExtractedFile.c_str(), "w");
  if (!f)
    util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
  std::fclose(f);

  fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}

/// @brief Run one pass over loader, optionally optimizing (train) and printing
/// progress every ~printInterval examples.
/// @return Total loss divided by nbExamples.
float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples)
{
  constexpr int printInterval = 50;
  int nbExamplesProcessed = 0;
  int totalNbExamplesProcessed = 0;
  float totalLoss = 0.0;
  float lossSoFar = 0.0;

  torch::AutoGradMode useGrad(train); // gradients only when optimizing
  machine.trainMode(train);

  auto pastTime = std::chrono::high_resolution_clock::now();

  for (auto & batch : *loader)
  {
    if (train)
      machine.getClassifier()->getOptimizer().zero_grad();

    auto data = std::get<0>(batch);
    auto labels = std::get<1>(batch);
    auto state = std::get<2>(batch); // each batch carries the machine state it was extracted in

    machine.getClassifier()->setState(state);

    auto prediction = machine.getClassifier()->getNN()(data);
    if (prediction.dim() == 1)
      prediction = prediction.unsqueeze(0); // ensure a batch dimension

    auto loss = machine.getClassifier()->getLossMultiplier()*lossFct(prediction, labels);
    float lossAsFloat = 0.0;
    try
    {
      lossAsFloat = loss.item<float>();
    } catch (std::exception & e)
    {
      util::myThrow(e.what());
    }

    totalLoss += lossAsFloat;
    lossSoFar += lossAsFloat;

    if (train)
    {
      loss.backward();
      machine.getClassifier()->getOptimizer().step();
    }

    totalNbExamplesProcessed += labels.size(0);

    if (printAdvancement)
    {
      nbExamplesProcessed += labels.size(0);
      if (nbExamplesProcessed >= printInterval)
      {
        auto actualTime = std::chrono::high_resolution_clock::now();
        double seconds = std::chrono::duration<double, std::milli>(actualTime-pastTime).count() / 1000.0;
        pastTime = actualTime;
        auto speed = (int)(nbExamplesProcessed/seconds);
        auto progression = 100.0*totalNbExamplesProcessed / nbExamples;
        auto statusStr = fmt::format("{:6.2f}% loss={:<7.3f} speed={:<6}ex/s", progression, lossSoFar, speed);
        if (train)
          fmt::print(stderr, "\r{:80}\rtraining : {}", "", statusStr);
        else
          fmt::print(stderr, "\r{:80}\reval on dev : {}", "", statusStr);
        lossSoFar = 0;
        nbExamplesProcessed = 0;
      }
    }
  }

  return totalLoss / nbExamples;
}

/// @brief One training epoch over the train loader; returns the mean loss.
float Trainer::epoch(bool printAdvancement)
{
  return processDataset(dataLoader, true, printAdvancement, trainDataset->size().value());
}

/// @brief One evaluation pass over the dev loader; returns the mean loss.
float Trainer::evalOnDev(bool printAdvancement)
{
  return processDataset(devDataLoader, false, printAdvancement, devDataset->size().value());
}

/// @brief Dump buffered (context, class) pairs to a tensor file once at least
/// `threshold` new examples accumulated (threshold 0 forces a flush), then
/// clear the buffers. The filename encodes state, nbClasses, the covered
/// example index range, epoch and oracle mode.
void Trainer::Examples::saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int epoch, bool dynamicOracle)
{
  if (currentExampleIndex-lastSavedIndex < (int)threshold)
    return;
  if (contexts.empty())
    return;

  int nbClasses = classes[0].size(0);

  // Contexts and classes are concatenated column-wise into a single tensor.
  auto tensorToSave = torch::cat({torch::stack(contexts), torch::stack(classes)}, 1);
  auto filename = fmt::format("{}-{}_{}-{}.{}.{}.tensor", state, nbClasses, lastSavedIndex, currentExampleIndex-1, epoch, dynamicOracle);
  torch::save(tensorToSave, dir/filename);

  lastSavedIndex = currentExampleIndex;
  contexts.clear();
  classes.clear();
}

/// @brief Append one context tensor per extracted context row; the data is
/// cloned because from_blob only borrows the caller's buffer.
void Trainer::Examples::addContext(std::vector<std::vector<long>> & context)
{
  for (auto & element : context)
    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone());

  currentExampleIndex += context.size();
}

/// @brief Append the gold-class tensor for every context added since the last
/// call, so classes stays aligned with contexts.
void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes)
{
  auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);

  while (classes.size() < contexts.size())
    classes.emplace_back(gold);
}

/// @brief Parse a TrainAction from its exact string name; throws on unknown input.
Trainer::TrainAction Trainer::str2TrainAction(const std::string & s)
{
  if (s == "ExtractGold")
    return TrainAction::ExtractGold;
  if (s == "ExtractDynamic")
    return TrainAction::ExtractDynamic;
  if (s == "DeleteExamples")
    return TrainAction::DeleteExamples;
  if (s == "ResetOptimizer")
    return TrainAction::ResetOptimizer;
  if (s == "ResetParameters")
    return TrainAction::ResetParameters;
  if (s == "Save")
    return TrainAction::Save;

  util::myThrow(fmt::format("unknown TrainAction '{}'", s));
  return TrainAction::ExtractGold;
}