diff --git a/decoder/include/Beam.hpp b/decoder/include/Beam.hpp index 2c34d3f9a81792ddda88506ebb88c4fae944078b..1dd40184acbe7f0adb75a24201939959ca12083d 100644 --- a/decoder/include/Beam.hpp +++ b/decoder/include/Beam.hpp @@ -21,6 +21,7 @@ class Beam int nbTransitions = 0; double totalProbability{0.0}; bool ended{false}; + float entropy{0.0}; public : diff --git a/decoder/src/Beam.cpp b/decoder/src/Beam.cpp index 16d5a32c17d220c75e5c6ab66ab3ca799dbd0e45..32e7360e02fd07bbb6e8c201463312a9e3fb867b 100644 --- a/decoder/src/Beam.cpp +++ b/decoder/src/Beam.cpp @@ -51,7 +51,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0); - + float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction); std::vector<std::pair<float, int>> scoresOfTransitions; for (unsigned int i = 0; i < prediction.size(0); i++) { @@ -80,6 +80,7 @@ void Beam::update(ReadingMachine & machine, bool debug) elements.back().config.setChosenActionScore(scoresOfTransitions[i].first); elements.back().nbTransitions++; elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions; + elements.back().entropy = entropy; } elements[index].nextTransition = scoresOfTransitions[0].second; @@ -89,6 +90,7 @@ void Beam::update(ReadingMachine & machine, bool debug) elements[index].name.push_back("0"); elements[index].meanProbability = 0.0; elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions; + elements[index].entropy = entropy; if (debug) { @@ -127,7 +129,7 @@ void Beam::update(ReadingMachine & machine, bool debug) auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition); - transition->apply(config); + transition->apply(config, element.entropy); config.addToHistory(transition->getName()); auto movement = config.getStrategy().getMovement(config, transition->getName()); diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index 7739b1de6096be6fedb313b8d81ed974ee493e99..22f9689ccb23687aec0ba4b21184c1fc61622404 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -43,7 +43,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1) { - machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig); + machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig, 0.0); if (debug) { fmt::print(stderr, "Forcing EOS transition\n"); diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 4d88e48118f90fe3edadd920b98e46c1be1a3a60..1ad7ab700c5c837ff83edac8043f9b804fecd37f 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -43,6 +43,7 @@ class Action static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition); + static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition); static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis); static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition); diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp index 3e76c724245800586f78482e0f974689f93b8fb6..75f9c40643ae3487ba54ddddd14349e65cb581f5 100644 --- a/reading_machine/include/Transition.hpp +++ b/reading_machine/include/Transition.hpp @@ -61,6 +61,7 @@ class Transition public : Transition(const std::string & name); + void apply(Config & config, float entropy); void apply(Config & config); bool appliable(const Config & config) const; int getCostDynamic(const Config & config) const; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index cd4c20cf66d641eb8eea2063cac1ab79c00d44b4..a269968ca658db6f1d2cdaa7aadbcbd173ca8404 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -184,6 +184,55 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde return {Type::Write, apply, undo, appliable}; } +Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition) +{ + auto apply = [colName, lineIndex, addition](Config & config, Action &) + { + // Knuth’s algorithm for online mean + auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + auto splited = util::split(curStr, '|'); + float curVal = 0.0; + int curNb = 0; + if (splited.size() == 2) + { + curVal = std::stof(splited[0]); + curNb = std::stoi(splited[1]); + } + + curNb += 1; + float delta = addition - curVal; + curVal += delta / curNb; + + config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb); + }; + + auto undo = [colName, lineIndex, addition](Config & config, Action &) + { + // Knuth’s algorithm for online mean + auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + auto splited = util::split(curStr, '|'); + float curVal = 0.0; + int curNb = 0; + if (splited.size() == 2) + { + curVal = std::stof(splited[0]); + curNb = std::stoi(splited[1]); + } + + curNb -= 1; + curVal = (curNb*curVal - addition) / (curNb - 1); + + config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb); + }; + + auto appliable = [colName, lineIndex, addition](const Config & config, const Action &) + { + return config.has(colName, lineIndex, 0); + }; + + return {Type::Write, apply, undo, appliable}; +} + Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition) { auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a) diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 29dcb6cba418d988ca4a9bc1c4a7776995c83131..879c074f8e8de9797097ba849194b8226e07d955 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -97,6 +97,17 @@ Transition::Transition(const std::string & name) } +void Transition::apply(Config & config, float entropy) +{ + if (config.hasColIndex("ENTROPY")) + { + auto action = Action::sumToHypothesis("ENTROPY", config.getWordIndex(), entropy); + action.apply(config, action); + } + + apply(config); +} + void Transition::apply(Config & config) { for (Action & action : sequence) diff --git a/torch_modules/include/NeuralNetwork.hpp b/torch_modules/include/NeuralNetwork.hpp index ee32d2b2eadc666ef7e38ac70b8ed9f64055d3e4..6058cebf9c39dff266622c656aefff66dd094f7b 100644 --- a/torch_modules/include/NeuralNetwork.hpp +++ b/torch_modules/include/NeuralNetwork.hpp @@ -27,6 +27,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St virtual void setDictsState(Dict::State state) = 0; virtual void setCountOcc(bool countOcc) = 0; virtual void removeRareDictElements(float rarityThreshold) = 0; + + static float entropy(torch::Tensor probabilities); }; TORCH_MODULE(NeuralNetwork); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 02e8a191bfb4b2bc718b6e815a266bec252fb24b..dbab2ebdb9a8e38ccd0cd660e2b5a56a10c3b829 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -2,3 +2,15 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); +float NeuralNetworkImpl::entropy(torch::Tensor probabilities) +{ + float res = 0.0; + for (unsigned int i = 0; i < probabilities.size(0); i++) + { + float val = probabilities[i].item<float>(); + res -= val * log(val); + } + + return res; +} + diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index bdd5b33515e23c8fa79bbb599666b7c3f10a73ee..0f0ad1a4cfa9fbbb1543d0a58a8943243101d208 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -83,11 +83,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p int nbClasses = machine.getTransitionSet(config.getState()).size(); float bestScore = -std::numeric_limits<float>::max(); + + float entropy = 0.0; if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter") { auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0); + entropy = NeuralNetworkImpl::entropy(prediction); std::vector<int> candidates; @@ -123,6 +126,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p float regressionTarget = 0.0; if (machine.getClassifier(config.getState())->isRegression()) { + entropy = 0.0; auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName()); auto splited = util::split(transition->getName(), ' '); if (splited.size() != 3 or splited[0] != "WRITESCORE") @@ -154,7 +158,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p config.setChosenActionScore(bestScore); - transition->apply(config); + transition->apply(config, entropy); config.addToHistory(transition->getName()); auto movement = config.getStrategy().getMovement(config, transition->getName());