diff --git a/common/include/util.hpp b/common/include/util.hpp index 90056c70617d9603db5c9b5037982e4da1048d2f..82620a56f08adcc1e8d354ad84a1af9c238d33c8 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -16,6 +16,9 @@ namespace util { + +constexpr float float2longScale = 10000; + void warning(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); void error(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current()); void error(const std::exception & e, const std::experimental::source_location & location = std::experimental::source_location::current()); @@ -31,7 +34,7 @@ utf8string splitAsUtf8(std::string_view s); std::string int2HumanStr(int number); -std::string shrink(const std::string & s, int printedSize); +std::string shrink(std::string s, int printedSize); std::string strip(const std::string & s); @@ -48,6 +51,9 @@ bool isNumber(const std::string & s); std::string getTime(); +long float2long(float f); +float long2float(long l); + template <typename T> bool isEmpty(const std::vector<T> & s) { diff --git a/common/src/util.cpp b/common/src/util.cpp index dd9a982104d3b82a224d7ea21d09034f3ec67734..2ce33306431ea96551d282fbd5c637690206ecd3 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -4,6 +4,16 @@ #include <algorithm> #include "upper2lower" +float util::long2float(long l) +{ + return l / util::float2longScale; +} + +long util::float2long(float f) +{ + return f * util::float2longScale; +} + int util::printedLength(std::string_view s) { return splitAsUtf8(s).size(); @@ -91,10 +101,20 @@ util::utf8string util::splitAsUtf8(std::string_view s) return result; } -std::string util::shrink(const std::string & s, int printedSize) +std::string util::shrink(std::string s, int printedSize) { static const std::string filler = "…"; + if (printedLength(s) <= printedSize) + return s; + + try + { + float value = std::stof(s); + s = fmt::format("{:{}.3f}", value, printedSize); + } + catch (std::exception &) {} + if (printedLength(s) <= printedSize) return s; diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 067587979e64eb219fcd1955a0e8d76341de1f03..16d5a32c17d220c75e5c6ab66ab3ca799dbd0e45 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -50,7 +50,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto context = classifier.getNN()->extractContext(elements[index].config).back(); auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); - auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); + auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++) @@ -76,13 +76,15 @@ void Beam::update(ReadingMachine & machine, bool debug) { elements.emplace_back(elements[index], scoresOfTransitions[i].second); elements.back().name.push_back(std::to_string(i)); - elements.back().totalProbability += scoresOfTransitions[i].first; + elements.back().totalProbability += classifier.isRegression() ? 1.0 : scoresOfTransitions[i].first; + elements.back().config.setChosenActionScore(scoresOfTransitions[i].first); elements.back().nbTransitions++; elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions; } elements[index].nextTransition = scoresOfTransitions[0].second; - elements[index].totalProbability += scoresOfTransitions[0].first; + elements[index].totalProbability += classifier.isRegression() ? 1.0 : scoresOfTransitions[0].first; + elements[index].config.setChosenActionScore(scoresOfTransitions[0].first); elements[index].nbTransitions++; elements[index].name.push_back("0"); elements[index].meanProbability = 0.0; diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 71cf7d5994b32dc8f8132a8db75fa2dee2f4d177..dbf344b1e3c24acacc60ef0b2a5da3368f49e1a3 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -47,6 +47,7 @@ class Action static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); + static Action writeScore(const std::string & colName, Config::Object object, int relativeIndex); static Action pushWordIndexOnStack(); static Action popStack(int relIndex); static Action emptyStack(); diff --git a/reading_machine/include/Classifier.hpp b/reading_machine/include/Classifier.hpp index 3e5e9507175db5cc28e3af391dc622da6c5f4ec2..f8f79fe88e87c10f55c0279e9e69a3ac0cde451d 100644 --- a/reading_machine/include/Classifier.hpp +++ b/reading_machine/include/Classifier.hpp @@ -5,6 +5,7 @@ #include <filesystem> #include "TransitionSet.hpp" #include "NeuralNetwork.hpp" +#include "LossFunction.hpp" class Classifier { @@ -24,6 +25,8 @@ class Classifier std::string state; std::vector<std::string> states; std::filesystem::path path; + bool regression{false}; + LossFunction lossFct; private : @@ -49,6 +52,8 @@ class Classifier void saveDicts(); void saveBest(); void saveLast(); + bool isRegression() const; + LossFunction & getLossFunction(); }; #endif diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index ecea742a7e18c384965824c25da40b0df3a7e0cf..7e660d399080fc4cee2f0cc8410705ec4dc95b9f 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -54,6 +54,7 @@ class Config String state{"NONE"}; boost::circular_buffer<String> history{10}; boost::circular_buffer<std::size_t> stack{50}; + float chosenActionScore{0.0}; std::vector<std::string> extraColumns{isMultiColName, childsColName, sentIdColName, EOSColName}; std::set<std::string> predicted; int lastPoppedStack{-1}; @@ -167,6 +168,8 @@ class Config Strategy & getStrategy(); std::size_t getCurrentSentenceStartRawInput() const; void setCurrentSentenceStartRawInput(std::size_t value); + void setChosenActionScore(float chosenActionScore); + float getChosenActionScore() const; }; #endif diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 58549791d0031abcbb0d5f192aa65d12854791f0..3e76c724245800586f78482e0f974689f93b8fb6 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -56,6 +56,7 @@ class Transition void initNothing(std::string col, std::string obj, std::string index); void initLowercase(std::string col, std::string obj, std::string index); void initLowercaseIndex(std::string col, std::string obj, std::string index, std::string inIndex); + void initWriteScore(std::string colName, std::string object, std::string index); public : diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index e04b5f6c13f6610d0eaf9e96c8649868f895a639..2ef4baa1cde8316079c87070646db897a18dd190 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -1053,3 +1053,38 @@ Action Action::lowercaseIndex(std::string col, Config::Object obj, int index, in return {Type::Write, apply, undo, appliable}; } + +Action Action::writeScore(const std::string & colName, Config::Object object, int relativeIndex) +{ + auto apply = [colName, object, relativeIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + float score = config.getChosenActionScore(); + if (score != std::numeric_limits<float>::min()) + return addHypothesis(colName, lineIndex, fmt::format("{}", score)).apply(config, a); + else + return addHypothesis(colName, lineIndex, config.getConst(colName, lineIndex, 0)).apply(config, a); + }; + + auto undo = [colName, object, relativeIndex](Config & config, Action & a) + { + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, "").undo(config, a); + }; + + auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a) + { + if (!config.hasRelativeWordIndex(object, relativeIndex)) + return false; + + int lineIndex = config.getRelativeWordIndex(object, relativeIndex); + + return addHypothesis(colName, lineIndex, "").appliable(config, a); + }; + + return {Type::Write, apply, undo, appliable}; + +} + diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp index e91351e738de6e7fcfdf73abc486b85dddff9798..99fbdd6c02c637c7debf2ebbf0dd60a589c811af 100644 --- a/reading_machine/src/Classifier.cpp +++ b/reading_machine/src/Classifier.cpp @@ -143,6 +143,28 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition) optimizerParameters = sm.str(2); })) util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Optimizer :) " + util::join("|", knownOptimizers))); + + curIndex++; + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Type :|)(?:(?:\\s|\\t)*)(.+)"), definition[curIndex], [&curIndex,this](auto sm) + { + auto type = sm.str(1); + if (util::lower(type) == "regression") + regression = true; + else if (util::lower(type) == "classification") + regression = false; + else + util::myThrow(fmt::format("Invalid type '{}' expected 'classification' or 'regression'", type)); + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Type :) (classification | regression)" )); + + curIndex++; + + if (curIndex >= definition.size() || !util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)(?:Loss :|)(?:(?:\\s|\\t)*)(.+)"), definition[curIndex], [&curIndex,this](auto sm) + { + lossFct.init(sm.str(1)); + })) + util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", curIndex < definition.size() ? definition[curIndex] : "", "(Loss :) lossName" )); } void Classifier::loadOptimizer() @@ -250,3 +272,13 @@ void Classifier::saveLast() saveOptimizer(); } +bool Classifier::isRegression() const +{ + return regression; +} + +LossFunction & Classifier::getLossFunction() +{ + return lossFct; +} + diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 5d095dcd1d039b596dde429795137c9c3b2484f0..7ad813bec25275b9defe30d7bb4ffbaca0674977 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -795,3 +795,13 @@ void Config::setCurrentSentenceStartRawInput(std::size_t value) currentSentenceStartRawInput = value; } +void Config::setChosenActionScore(float chosenActionScore) +{ + this->chosenActionScore = chosenActionScore; +} + +float Config::getChosenActionScore() const +{ + return chosenActionScore; +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 2b97ca0ee0fa8abb3047c06b24f1224e75acfecb..23097ff5e9bbe36d8adc11396d58a6979781d994 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -7,6 +7,8 @@ Transition::Transition(const std::string & name) { {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"), [this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}}, + {std::regex("WRITESCORE ([bs])\\.(.+) (.+)"), + [this](auto sm){(initWriteScore(sm[3], sm[1], sm[2]));}}, {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"), [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}}, {std::regex("eager_SHIFT"), @@ -164,6 +166,21 @@ void Transition::initWrite(std::string colName, std::string object, std::string costStatic = costDynamic; } +void Transition::initWriteScore(std::string colName, std::string object, std::string index) +{ + auto objectValue = Config::str2object(object); + int indexValue = std::stoi(index); + + sequence.emplace_back(Action::writeScore(colName, objectValue, indexValue)); + + costDynamic = [](const Config &) + { + return 0; + }; + + costStatic = costDynamic; +} + void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value) { auto objectValue = Config::str2object(object); diff --git a/torch_modules/include/LossFunction.hpp b/torch_modules/include/LossFunction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b845ab3104fc7af65ec05072956f52566be4b8c3 --- /dev/null +++ b/torch_modules/include/LossFunction.hpp @@ -0,0 +1,23 @@ +#ifndef LOSSFUNCTION__H +#define LOSSFUNCTION__H + +#include <variant> +#include "torch/torch.h" +#include "CustomHingeLoss.hpp" + +class LossFunction +{ + private : + + std::string name{"_undefined_loss_"}; + std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct; + + public : + + void init(std::string name); + torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold); + torch::Tensor getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const; +}; + +#endif + diff --git a/torch_modules/src/LossFunction.cpp b/torch_modules/src/LossFunction.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d39203b11978d6e26c01ae5ee1c1c524fb9dad67 --- /dev/null +++ b/torch_modules/src/LossFunction.cpp @@ -0,0 +1,64 @@ +#include "LossFunction.hpp" +#include "util.hpp" + +void LossFunction::init(std::string name) +{ + this->name = name; + + if (util::lower(name) == "crossentropy") + fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "bce") + fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "mse") + fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); + else if (util::lower(name) == "hinge") + fct = CustomHingeLoss(); + else + util::myThrow(fmt::format("unknown loss function name '{}' available losses are 'crossentropy, bce, mse, hinge'", name)); +} + +torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold) +{ + try + { + auto index = fct.index(); + + if (index == 0) + return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0))); + if (index == 1) + return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); + if (index == 2) + return std::get<2>(fct)(prediction, gold); + if (index == 3) + return std::get<3>(fct)(torch::softmax(prediction, 1), gold); + } catch (std::exception & e) + { + util::myThrow(fmt::format("computing loss '{}' caught '{}'", name, e.what())); + } + + util::myThrow("loss is not defined"); + return torch::Tensor(); +} + +torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<long> & goldIndexes) const +{ + auto index = fct.index(); + + if (index == 0 or index == 2) + { + auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); + gold[0] = goldIndexes.at(0); + return gold; + } + if (index == 1 or index == 3) + { + auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); + for (auto goldIndex : goldIndexes) + gold[goldIndex] = 1; + return gold; + } + + util::myThrow("loss is not defined"); + return torch::Tensor(); +} + diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index 25d48f4b567bbc0ed5934c2755cc4db17e36d114..a61eaf7c6637afb61525c81377252e386486d2c0 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -4,20 +4,6 @@ #include "ReadingMachine.hpp" #include "ConfigDataset.hpp" #include "SubConfig.hpp" -#include "CustomHingeLoss.hpp" - -class LossFunction -{ - private : - - std::variant<torch::nn::CrossEntropyLoss, torch::nn::BCELoss, torch::nn::MSELoss, CustomHingeLoss> fct; - - public : - - LossFunction(std::string name); - torch::Tensor operator()(torch::Tensor prediction, torch::Tensor gold); - torch::Tensor getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const; -}; class Trainer { @@ -49,7 +35,7 @@ class Trainer void saveIfNeeded(const std::string & state, std::filesystem::path dir, std::size_t threshold, int currentEpoch, bool dynamicOracle); void addContext(std::vector<std::vector<long>> & context); - void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes); + void addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes); }; private : @@ -66,7 +52,6 @@ class Trainer DataLoader devDataLoader{nullptr}; std::size_t epochNumber{0}; int batchSize; - LossFunction lossFct; private : @@ -75,7 +60,7 @@ class Trainer public : - Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName); + Trainer(ReadingMachine & machine, int batchSize); void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, bool dynamicOracle, float explorationThreshold); void makeDataLoader(std::filesystem::path dir); void makeDevDataLoader(std::filesystem::path dir); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 5d0dbaa2e30e9edf761ad4a1e33113b106a5a0fc..923bbff97e1936c51a556f402d6449520454348a 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -40,8 +40,6 @@ po::options_description MacaonTrain::getOptionsDescription() "Reading machine file content") ("trainStrategy", po::value<std::string>()->default_value("0,ExtractGold,ResetParameters"), "Description of what should happen during training") - ("loss", po::value<std::string>()->default_value("CrossEntropy"), - "Loss function to use during training : CrossEntropy | bce | mse | hinge") ("seed", po::value<int>()->default_value(100), "Number of examples per batch") ("scaleGrad", "Scale embedding's gradient with its frequence in the minibatch") @@ -135,7 +133,6 @@ int MacaonTrain::main() bool computeDevScore = variables.count("devScore") == 0 ? false : true; auto machineContent = variables["machine"].as<std::string>(); auto trainStrategyStr = variables["trainStrategy"].as<std::string>(); - auto lossFunction = variables["loss"].as<std::string>(); auto explorationThreshold = variables["explorationThreshold"].as<float>(); auto seed = variables["seed"].as<int>(); WordEmbeddingsImpl::setMaxNorm(variables["maxNorm"].as<float>()); @@ -167,7 +164,7 @@ int MacaonTrain::main() BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); - Trainer trainer(machine, batchSize, lossFunction); + Trainer trainer(machine, batchSize); Decoder decoder(machine); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 40a7c77fd2d828a427ddd0e882cfbac85c90bc8e..d78b98038987d120937448ac0b8eb9a9c1a2c3bf 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -1,60 +1,7 @@ #include "Trainer.hpp" #include "SubConfig.hpp" -LossFunction::LossFunction(std::string name) -{ - if (util::lower(name) == "crossentropy") - fct = torch::nn::CrossEntropyLoss(torch::nn::CrossEntropyLossOptions().reduction(torch::kMean)); - else if (util::lower(name) == "bce") - fct = torch::nn::BCELoss(torch::nn::BCELossOptions().reduction(torch::kMean)); - else if (util::lower(name) == "mse") - fct = torch::nn::MSELoss(torch::nn::MSELossOptions().reduction(torch::kMean)); - else if (util::lower(name) == "hinge") - fct = CustomHingeLoss(); - else - util::myThrow(fmt::format("unknown loss function name '{}'", name)); -} - -torch::Tensor LossFunction::operator()(torch::Tensor prediction, torch::Tensor gold) -{ - auto index = fct.index(); - - if (index == 0) - return std::get<0>(fct)(prediction, gold.reshape(gold.dim() == 0 ? 1 : gold.size(0))); - if (index == 1) - return std::get<1>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); - if (index == 2) - return std::get<2>(fct)(torch::softmax(prediction, 1), gold.to(torch::kFloat)); - if (index == 3) - return std::get<3>(fct)(torch::softmax(prediction, 1), gold); - - util::myThrow("loss is not defined"); - return torch::Tensor(); -} - -torch::Tensor LossFunction::getGoldFromClassesIndexes(int nbClasses, const std::vector<int> & goldIndexes) const -{ - auto index = fct.index(); - - if (index == 0) - { - auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong)); - gold[0] = goldIndexes.at(0); - return gold; - } - if (index == 1 or index == 2 or index == 3) - { - auto gold = torch::zeros(nbClasses, torch::TensorOptions(torch::kLong)); - for (auto goldIndex : goldIndexes) - gold[goldIndex] = 1; - return gold; - } - - util::myThrow("loss is not defined"); - return torch::Tensor(); -} - -Trainer::Trainer(ReadingMachine & machine, int batchSize, std::string lossFunctionName) : machine(machine), batchSize(batchSize), lossFct(lossFunctionName) +Trainer::Trainer(ReadingMachine & machine, int batchSize) : machine(machine), batchSize(batchSize) { } @@ -134,13 +81,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p goldTransitions[std::rand()%goldTransitions.size()]; int nbClasses = machine.getTransitionSet(config.getState()).size(); + + float bestScore = std::numeric_limits<float>::min(); if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); - float bestScore = std::numeric_limits<float>::min(); std::vector<int> candidates; for (unsigned int i = 0; i < prediction.size(0); i++) @@ -170,18 +118,42 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p util::myThrow("No transition appliable !"); } + std::vector<long> goldIndexes; + + float regressionTarget = 0.0; + if (machine.getClassifier(config.getState())->isRegression()) + { + auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); + auto splited = util::split(transition->getName(), ' '); + if (splited.size() != 3 or splited[0] != "WRITESCORE") + util::myThrow(errMessage); + auto col = splited[2]; + splited = util::split(splited[1], '.'); + if (splited.size() != 2) + util::myThrow(errMessage); + auto object = Config::str2object(splited[0]); + int index = std::stoi(splited[1]); + + regressionTarget = std::stof(config.getConst(col, config.getRelativeWordIndex(object, index), 0)); + goldIndexes.emplace_back(util::float2long(regressionTarget)); + } + else + { + for (auto & t : goldTransitions) + goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); + + } + totalNbExamples += context.size(); if (totalNbExamples >= (int)safetyNbExamplesMax) util::myThrow(fmt::format("Trying to extract more examples than the limit ({})", util::int2HumanStr(safetyNbExamplesMax))); - std::vector<int> goldIndexes; - for (auto & t : goldTransitions) - goldIndexes.emplace_back(machine.getTransitionSet(config.getState()).getTransitionIndex(t)); - examplesPerState[config.getState()].addContext(context); - examplesPerState[config.getState()].addClass(lossFct, nbClasses, goldIndexes); + examplesPerState[config.getState()].addClass(machine.getClassifier(config.getState())->getLossFunction(), nbClasses, goldIndexes); examplesPerState[config.getState()].saveIfNeeded(config.getState(), dir, maxNbExamplesPerFile, epoch, dynamicOracle); + config.setChosenActionScore(bestScore); + transition->apply(config); config.addToHistory(transition->getName()); @@ -238,7 +210,13 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance if (prediction.dim() == 1) prediction = prediction.unsqueeze(0); - auto loss = machine.getClassifier(state)->getLossMultiplier()*lossFct(prediction, labels); + if (machine.getClassifier(state)->isRegression()) + { + labels = labels.to(torch::kFloat); + labels /= util::float2longScale; + } + + auto loss = machine.getClassifier(state)->getLossMultiplier()*machine.getClassifier(state)->getLossFunction()(prediction, labels); float lossAsFloat = 0.0; try { @@ -316,7 +294,7 @@ void Trainer::Examples::addContext(std::vector<std::vector<long>> & context) currentExampleIndex += context.size(); } -void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<int> & goldIndexes) +void Trainer::Examples::addClass(const LossFunction & lossFct, int nbClasses, const std::vector<long> & goldIndexes) { auto gold = lossFct.getGoldFromClassesIndexes(nbClasses, goldIndexes);