Commit 82afc4f3 authored by Franck Dary's avatar Franck Dary
Browse files

Added more complexity metrics

parent f71baac6
......@@ -44,6 +44,7 @@ class Action
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action sumToHypothesis(const std::string & colName, std::size_t lineIndex, float addition, bool mean);
static Action maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition);
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addHypothesisRelativeRelaxed(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
......
......@@ -295,6 +295,63 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde
return {Type::Write, apply, undo, appliable};
}
Action Action::maxWithHypothesis(const std::string & colName, std::size_t lineIndex, float addition)
{
auto apply = [colName, lineIndex, addition](Config & config, Action &)
{
std::string totalStr = std::string(config.getLastNotEmptyHypConst(colName, lineIndex));
if (totalStr.empty() || totalStr == "_")
totalStr = fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0);
auto byStates = util::split(totalStr, ',');
int index = -1;
for (unsigned int i = 0; i < byStates.size(); i++)
{
auto state = util::split(byStates[i], '=')[0];
if (state == config.getState())
{
index = i;
break;
}
}
if (index == -1)
{
byStates.emplace_back(fmt::format("{}={}|{}", std::string(config.getState()), 0.0, 0));
index = byStates.size()-1;
}
auto splited = util::split(util::split(byStates[index], '=')[1], '|');
float curVal = 0.0;
int curNb = 0;
if (splited.size() == 2)
{
curVal = std::stof(splited[0]);
curNb = std::stoi(splited[1]);
}
curNb += 1;
curVal = addition;
byStates[index] = fmt::format("{}={}|{}", std::string(config.getState()), curVal, curNb);
config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates);
};
auto undo = [](Config &, Action &)
{
//TODO: not done
};
auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)
{
return config.has(colName, lineIndex, 0);
};
return {Type::Write, apply, undo, appliable};
}
Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition)
{
auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
......
......@@ -116,67 +116,232 @@ void Transition::apply(Config & config, float entropy)
curValue = fmt::format("_");
}
}
// Entropy of the action that attach a word to the tree
if (config.hasColIndex("ENTROPY_ATTACH"))
float surprisal = -log(config.getChosenActionScore());
if (config.hasColIndex("ENT_CUR_ADD"))
{
auto action = Action::sumToHypothesis("ENT_CUR_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
}
if (config.hasColIndex("ENT_CUR_MEAN"))
{
auto action = Action::sumToHypothesis("ENT_CUR_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
}
if (config.hasColIndex("ENT_CUR_MAX"))
{
auto action = Action::maxWithHypothesis("ENT_CUR_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
}
if (config.hasColIndex("ENT_ATT_ADD"))
{
bool mean = false;
if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
if (name.find("REDUCE") != std::string::npos)
{
if (name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getStack(0), entropy, mean);
action.apply(config, action);
}
else
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH", config.getWordIndex(), entropy, mean);
auto action = Action::sumToHypothesis("ENT_ATT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
}
}
if (config.hasColIndex("ENTROPY_ATTACH_MEAN"))
if (config.hasColIndex("ENT_ATT_MEAN"))
{
bool mean = true;
if (name.find("SHIFT") == std::string::npos and name.find("REDUCE") == std::string::npos)
if (name.find("REDUCE") != std::string::npos)
{
if (name.find("LEFT") != std::string::npos)
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getStack(0), entropy, mean);
auto action = Action::sumToHypothesis("ENT_ATT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
}
}
if (config.hasColIndex("ENT_ATT_MAX"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENTROPY_ATTACH_MEAN", config.getWordIndex(), entropy, mean);
auto action = Action::maxWithHypothesis("ENT_ATT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
}
}
// Entropy of every action is taken into account
if (config.hasColIndex("ENTROPY_ALL"))
if (config.hasColIndex("ENT_TGT_ADD"))
{
bool mean = false;
auto action = Action::sumToHypothesis("ENTROPY_ALL", config.getWordIndex(), entropy, mean);
action.apply(config, action);
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getStack(0), entropy, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_TGT_ADD", config.getWordIndex(), entropy, false);
action.apply(config, action);
}
}
if (config.hasColIndex("ENTROPY_ALL_MEAN"))
if (config.hasColIndex("ENT_TGT_MEAN"))
{
bool mean = true;
auto action = Action::sumToHypothesis("ENTROPY_ALL_MEAN", config.getWordIndex(), entropy, mean);
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getStack(0), entropy, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("ENT_TGT_MEAN", config.getWordIndex(), entropy, true);
action.apply(config, action);
}
}
if (config.hasColIndex("ENT_TGT_MAX"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getStack(0), entropy);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("ENT_TGT_MAX", config.getWordIndex(), entropy);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_CUR_ADD"))
{
auto action = Action::sumToHypothesis("SUR_CUR_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
}
if (config.hasColIndex("SURPRISAL"))
if (config.hasColIndex("SUR_CUR_MEAN"))
{
float surprisal = -log(config.getChosenActionScore());
auto action = Action::sumToHypothesis("SURPRISAL", config.getWordIndex(), surprisal, false);
auto action = Action::sumToHypothesis("SUR_CUR_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
}
if (config.hasColIndex("SURPRISAL_MEAN"))
if (config.hasColIndex("SUR_CUR_MAX"))
{
float surprisal = -log(config.getChosenActionScore());
auto action = Action::sumToHypothesis("SURPRISAL_MEAN", config.getWordIndex(), surprisal, true);
auto action = Action::maxWithHypothesis("SUR_CUR_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
}
if (config.hasColIndex("SUR_ATT_ADD"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_ATT_MEAN"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_ATT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_ATT_MAX"))
{
if (name.find("REDUCE") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
if (name.find("LEFT") != std::string::npos || name.find("RIGHT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_ATT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
}
}
if (config.hasColIndex("SUR_TGT_ADD"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getStack(0), surprisal, false);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_TGT_ADD", config.getWordIndex(), surprisal, false);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_TGT_MEAN"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getStack(0), surprisal, true);
action.apply(config, action);
}
else
{
auto action = Action::sumToHypothesis("SUR_TGT_MEAN", config.getWordIndex(), surprisal, true);
action.apply(config, action);
}
}
if (config.hasColIndex("SUR_TGT_MAX"))
{
if (name.find("REDUCE") != std::string::npos || name.find("LEFT") != std::string::npos)
{
auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getStack(0), surprisal);
action.apply(config, action);
}
else
{
auto action = Action::maxWithHypothesis("SUR_TGT_MAX", config.getWordIndex(), surprisal);
action.apply(config, action);
}
}
apply(config);
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment