diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index a269968ca658db6f1d2cdaa7aadbcbd173ca8404..a5592db340af4aa95eda5914eee63859d0fe05bf 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -188,9 +188,30 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde { auto apply = [colName, lineIndex, addition](Config & config, Action &) { + std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + + if (totalStr.empty() || totalStr == "_") + totalStr = fmt::format("{}={}|{}", config.getState(), 0.0, 0); + + auto byStates = util::split(totalStr, ','); + int index = -1; + for (unsigned int i = 0; i < byStates.size(); i++) + { + auto state = util::split(byStates[i], '=')[0]; + if (state == config.getState()) + { + index = i; + break; + } + } + if (index == -1) + { + byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); + index = byStates.size()-1; + } + // Knuth’s algorithm for online mean - auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); - auto splited = util::split(curStr, '|'); + auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; if (splited.size() == 2) @@ -203,14 +224,37 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde float delta = addition - curVal; curVal += delta / curNb; - config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb); + byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); + + config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; auto undo = [colName, lineIndex, addition](Config & config, Action &) { + std::string totalStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); + + if (totalStr.empty() || totalStr == "_") + totalStr = fmt::format("{}={}|{}", config.getState(), 0.0, 0); + + auto byStates = util::split(totalStr, ','); + int index = -1; + for (unsigned int i = 0; i < byStates.size(); i++) + { + auto state = util::split(byStates[i], '=')[0]; + if (state == config.getState()) + { + index = i; + break; + } + } + if (index == -1) + { + byStates.emplace_back(fmt::format("{}={}|{}", config.getState(), 0.0, 0)); + index = byStates.size()-1; + } + // Knuth’s algorithm for online mean - auto curStr = config.getLastNotEmptyHypConst(colName, lineIndex).get(); - auto splited = util::split(curStr, '|'); + auto splited = util::split(util::split(byStates[index], '=')[1], '|'); float curVal = 0.0; int curNb = 0; if (splited.size() == 2) @@ -222,7 +266,9 @@ Action Action::sumToHypothesis(const std::string & colName, std::size_t lineInde curNb -= 1; curVal = (curNb*curVal - addition) / (curNb - 1); - config.getLastNotEmptyHyp(colName, lineIndex) = fmt::format("{}|{}", curVal, curNb); + byStates[index] = fmt::format("{}={}|{}", config.getState(), curVal, curNb); + + config.getLastNotEmptyHyp(colName, lineIndex) = util::join(",", byStates); }; auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)