Skip to content
Snippets Groups Projects
Commit 4cb5c474 authored by Franck Dary
Browse files

Output entropy when there is a column named 'ENTROPY'

parent 1e98bc42
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ class Beam
int nbTransitions = 0;
double totalProbability{0.0};
bool ended{false};
float entropy{0.0};
public :
......
......@@ -51,7 +51,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
float entropy = classifier.isRegression() ? 0.0 : NeuralNetworkImpl::entropy(prediction);
std::vector<std::pair<float, int>> scoresOfTransitions;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
......@@ -80,6 +80,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
elements.back().config.setChosenActionScore(scoresOfTransitions[i].first);
elements.back().nbTransitions++;
elements.back().meanProbability = elements.back().totalProbability / elements.back().nbTransitions;
elements.back().entropy = entropy;
}
elements[index].nextTransition = scoresOfTransitions[0].second;
......@@ -89,6 +90,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
elements[index].name.push_back("0");
elements[index].meanProbability = 0.0;
elements[index].meanProbability = elements[index].totalProbability / elements[index].nbTransitions;
elements[index].entropy = entropy;
if (debug)
{
......@@ -127,7 +129,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto * transition = machine.getTransitionSet(config.getState()).getTransition(element.nextTransition);
transition->apply(config);
transition->apply(config, element.entropy);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
......
......@@ -43,7 +43,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamTh
if (machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{
machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig);
machine.getTransitionSet(baseConfig.getState()).getTransition("EOS b.0")->apply(baseConfig, 0.0);
if (debug)
{
fmt::print(stderr, "Forcing EOS transition\n");
......
......@@ -43,6 +43,7 @@ class Action
static Action moveCharacterIndex(int movement);
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition);
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
......
......@@ -61,6 +61,7 @@ class Transition
public :
Transition(const std::string & name);
void apply(Config & config, float entropy);
void apply(Config & config);
bool appliable(const Config & config) const;
int getCostDynamic(const Config & config) const;
......
......@@ -184,6 +184,55 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde
return {Type::Write, apply, undo, appliable};
}
// Builds a Write action that maintains a running mean in the cell of column
// `colName` at line `lineIndex`. The cell text has the form "mean|count".
//  - apply : folds `addition` into the mean (incremental mean update).
//  - undo  : removes `addition` from the mean, i.e. the exact inverse of apply.
//  - appliable : true when the config has that cell (config.has(colName, lineIndex, 0)).
Action Action::sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition)
{
  auto apply = [colName, lineIndex, addition](Config & config, Action &)
  {
    // Incremental (online) mean: mean_n = mean_{n-1} + (x - mean_{n-1}) / n
    auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get();
    auto splited = util::split(curStr, '|');

    // Any text that is not exactly "mean|count" is treated as "no sample yet".
    float curVal = 0.0;
    int curNb = 0;
    if (splited.size() == 2)
    {
      curVal = std::stof(splited[0]);
      curNb = std::stoi(splited[1]);
    }

    curNb += 1;
    float delta = addition - curVal;
    curVal += delta / curNb;

    config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb);
  };

  auto undo = [colName, lineIndex, addition](Config & config, Action &)
  {
    // Inverse of the online mean update: mean_{n-1} = (n * mean_n - x) / (n - 1)
    auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get();
    auto splited = util::split(curStr, '|');

    float curVal = 0.0;
    int curNb = 0;
    if (splited.size() == 2)
    {
      curVal = std::stof(splited[0]);
      curNb = std::stoi(splited[1]);
    }

    if (curNb <= 1)
    {
      // Removing the only sample (or nothing to remove): the formula above
      // would divide by zero. The text that was in the cell before the first
      // apply is not recorded, so fall back to the zero state "0|0", which
      // apply parses back as mean 0 / count 0.
      curVal = 0.0;
      curNb = 0;
    }
    else
    {
      // The count must be used BEFORE it is decremented.
      curVal = (curNb * curVal - addition) / (curNb - 1);
      curNb -= 1;
    }

    config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb);
  };

  auto appliable = [colName, lineIndex](const Config & config, const Action &)
  {
    return config.has(colName, lineIndex, 0);
  };

  return {Type::Write, apply, undo, appliable};
}
Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition)
{
auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
......
......@@ -97,6 +97,17 @@ Transition::Transition(const std::string & name)
}
// Applies this transition to `config`, first recording `entropy` when the
// config exposes a column named "ENTROPY".
// The entropy is folded into the cell at the current word index through
// Action::sumToHypothesis, and only then is the regular apply(config) run.
void Transition::apply(Config & config, float entropy)
{
  const bool hasEntropyColumn = config.hasColIndex("ENTROPY");

  if (hasEntropyColumn)
  {
    // Read the word index before the transition itself runs.
    const auto targetLine = config.getWordIndex();
    Action entropyUpdate = Action::sumToHypothesis("ENTROPY", targetLine, entropy);
    entropyUpdate.apply(config, entropyUpdate);
  }

  apply(config);
}
void Transition::apply(Config & config)
{
for (Action & action : sequence)
......
......@@ -27,6 +27,8 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St
virtual void setDictsState(Dict::State state) = 0;
virtual void setCountOcc(bool countOcc) = 0;
virtual void removeRareDictElements(float rarityThreshold) = 0;
static float entropy(torch::Tensor probabilities);
};
TORCH_MODULE(NeuralNetwork);
......
......@@ -2,3 +2,15 @@
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
// Computes the Shannon entropy -sum_i p_i * ln(p_i) (natural logarithm) of
// `probabilities`.
// The tensor is walked along dimension 0 and every element is read as a single
// float, so it is expected to be one-dimensional.
// Returns 0 for an empty tensor.
float NeuralNetworkImpl::entropy(torch::Tensor probabilities)
{
  float res = 0.0;

  // size(0) is a signed 64-bit value; use a matching index type.
  const int64_t nbElements = probabilities.size(0);
  for (int64_t i = 0; i < nbElements; i++)
  {
    float val = probabilities[i].item<float>();

    // A probability of exactly 0 (e.g. softmax underflow) contributes nothing:
    // lim p->0 of p*ln(p) is 0, whereas 0 * log(0) evaluates to NaN and would
    // poison the whole sum.
    if (val == 0.0f)
      continue;

    res -= val * std::log(val);
  }

  return res;
}
......@@ -83,11 +83,14 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
int nbClasses = machine.getTransitionSet(config.getState()).size();
float bestScore = -std::numeric_limits<float>::max();
float entropy = 0.0;
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
......@@ -123,6 +126,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
float regressionTarget = 0.0;
if (machine.getClassifier(config.getState())->isRegression())
{
entropy = 0.0;
auto errMessage = fmt::format("Invalid regression transition '{}'", transition->getName());
auto splited = util::split(transition->getName(), ' ');
if (splited.size() != 3 or splited[0] != "WRITESCORE")
......@@ -154,7 +158,7 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
config.setChosenActionScore(bestScore);
transition->apply(config);
transition->apply(config, entropy);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment