Skip to content
Snippets Groups Projects
Commit 82afc4f3 authored by Franck Dary's avatar Franck Dary
Browse files

Added more complexity metrics

parent f71baac6
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,7 @@ class Action
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean);
static Action maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition);
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
......
......@@ -295,6 +295,63 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
return {Type::Write, apply, undo, appliable};
}
// Builds a Write action that keeps, for the state returned by config.getState(),
// the maximum of the values recorded in column `colName` at line `lineIndex`.
//
// The cell content is a ','-separated list of "state=value|count" entries, one
// per state. Applying the action updates the entry of the current state
// (creating it when absent): `count` is incremented and `value` becomes the
// largest `addition` seen so far for that state.
//
// colName   : name of the column holding the statistic.
// lineIndex : index of the line whose hypothesis is updated.
// addition  : new observation to compare against the stored maximum.
//
// The undo step is not implemented: undoing this action leaves the cell as is.
Action Action::maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition)
{
  auto apply = [colName, lineIndex, addition](Config & config, Action &)
  {
    std::string totalStr = std::string(config.getLastNotEmptyHypConst(colName, lineIndex));
    // An empty or "_" cell carries no statistic yet: start from a zeroed entry for the current state.
    if (totalStr.empty() || totalStr == "_")
      totalStr = fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0);

    auto byStates = util::split(totalStr, ',');

    // Locate the entry that belongs to the current state.
    int index = -1;
    for (unsigned int i = 0; i < byStates.size(); i++)
    {
      auto nameAndStats = util::split(byStates[i], '=');
      // Skip entries without a state name instead of indexing an empty split result.
      if (nameAndStats.size() > 0 && nameAndStats[0] == config.getState())
      {
        index = i;
        break;
      }
    }

    if (index == -1)
    {
      byStates.emplace_back(fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0));
      index = byStates.size()-1;
    }

    // Parse "value|count"; a malformed entry is treated as having no observation yet.
    float curVal = 0.0;
    int curNb = 0;
    auto nameAndStats = util::split(byStates[index], '=');
    if (nameAndStats.size() >= 2)
    {
      auto splited = util::split(nameAndStats[1], '|');
      if (splited.size() == 2)
      {
        curVal = std::stof(splited[0]);
        curNb = std::stoi(splited[1]);
      }
    }

    // Keep the maximum. The first observation always wins so that the
    // placeholder 0.0 of a fresh entry never masks a smaller real value.
    if (curNb == 0 || addition > curVal)
      curVal = addition;
    curNb += 1;

    byStates[index] = fmt::format("{}={}|{}", std::string(config.getState()), curVal, curNb);
    config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates);
  };

  auto undo = [](Config &, Action &)
  {
    //TODO: not done
  };

  // The action only requires the target cell to exist; the value itself is irrelevant here.
  auto appliable = [colName, lineIndex](const Config & config, const Action &)
  {
    return config.has(colName, lineIndex, 0);
  };

  return {Type::Write, apply, undo, appliable};
}
Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition)
{
auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
......
......@@ -116,67 +116,232 @@ void Transition::apply(Config & config, float entropy)
curValue = fmt::format("_");
}
}
// Entropy of the action that attach a word to the tree
if (config.hasColIndex("ENTROPY_ATTACH"))
float surprisal = -log(config.getChosenActionScore());
if (config.hasColIndex("ENT_CUR_ADD"))
{
auto action = Action::sumToHypothesis("ENT_CUR_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
}
if (config.hasColIndex("ENT_CUR_MEAN"))
{
auto action = Action::sumToHypothesis("ENT_CUR_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
}
if (config.hasColIndex("ENT_CUR_MAX"))
{
auto action = Action::maxWithHypothesis("ENT_CUR_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
}
if (config.hasColIndex("ENT_ATT_ADD"))
{
bool mean = false;
if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
if (name.find("REDUCE") != std::string::npos)
{
if (name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getStack(0), entropy, mean);
action.apply(config, action);
}
else
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getWordIndex(), entropy, mean);
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
}
}
if (config.hasColIndex("ENTROPY_ATTACH_MEAN"))
if (config.hasColIndex("ENT_ATT_MEAN"))
{
bool mean = true;
if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
if (name.find("REDUCE") != std::string::npos)
{
if (name.find("LEFT") != std::string::npos)
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getStack(0), entropy, mean);
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
}
}
if (config.hasColIndex("ENT_ATT_MAX"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getWordIndex(), entropy, mean);
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
}
}
// Entropy of every action is taken into account
if (config.hasColIndex("ENTROPY_ALL"))
if (config.hasColIndex("ENT_TGT_ADD"))
{
bool mean = false;
auto action = Action::sumToHypothesis("ENTROPY_ALL", config.getWordIndex(), entropy, mean);
action.apply(config, action);
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
}
}
if (config.hasColIndex("ENTROPY_ALL_MEAN"))
if (config.hasColIndex("ENT_TGT_MEAN"))
{
bool mean = true;
auto action = Action::sumToHypothesis("ENTROPY_ALL_MEAN", config.getWordIndex(), entropy, mean);
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
}
}
if (config.hasColIndex("ENT_TGT_MAX"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_CUR_ADD"))
{
auto action = Action::sumToHypothesis("SUR_CUR_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
}
if (config.hasColIndex("SURPRISAL"))
if (config.hasColIndex("SUR_CUR_MEAN"))
{
float surprisal = -log(config.getChosenActionScore());
auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false);
auto action = Action::sumToHypothesis("SUR_CUR_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
}
if (config.hasColIndex("SURPRISAL_MEAN"))
if (config.hasColIndex("SUR_CUR_MAX"))
{
float surprisal = -log(config.getChosenActionScore());
auto action = Action::sumToHypothesis("SURPRISAL_MEAN", config.getWordIndex(), surprisal, true);
auto action = Action::maxWithHypothesis("SUR_CUR_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
}
if (config.hasColIndex("SUR_ATT_ADD"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_ATT_MEAN"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_ATT_MAX"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_TGT_ADD"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_TGT_MEAN"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_TGT_MAX"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
}
}
apply(config);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment