diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 2dc319125094dd21e67a6db2305620be1161ad2b..4aa259b3a25d1cd1b6e7ba395433098158386b7f 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -51,13 +51,15 @@ class Action static Action setRoot(int bufferIndex); static Action updateIds(int bufferIndex); static Action endWord(); - static Action assertIsEmpty(const std::string & colName); + static Action assertIsEmpty(const std::string & colName, Config::Object object, int relativeIndex); + static Action assertIsNotEmpty(const std::string & colName, Config::Object object, int relativeIndex); static Action attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex); static Action addCurCharToCurWord(); static Action ignoreCurrentCharacter(); static Action consumeCharacterIndex(util::utf8string consumed); static Action setMultiwordIds(int multiwordSize); static Action split(int index); + static Action setRootUpdateIdsEmptyStackIfSentChanged(); }; #endif diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index 95871543b139b54ad35ce53c3e3798db1a4b13a2..9095082485b2f840222c3552cea0c810c40b8904 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -368,7 +368,7 @@ Action Action::ignoreCurrentCharacter() return {Type::MoveChar, apply, undo, appliable}; } -Action Action::assertIsEmpty(const std::string & colName) +Action Action::assertIsEmpty(const std::string & colName, Config::Object object, int relativeIndex) { auto apply = [](Config &, Action &) { @@ -378,9 +378,29 @@ Action Action::assertIsEmpty(const std::string & colName) { }; - auto appliable = [colName](const Config & config, const Action &) + auto appliable = [colName, object, relativeIndex](const Config & config, const Action &) { - return util::isEmpty(config.getAsFeature(colName, config.getWordIndex())); + auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return util::isEmpty(config.getAsFeature(colName, lineIndex)); + }; + + return {Type::Check, apply, undo, appliable}; +} + +Action Action::assertIsNotEmpty(const std::string & colName, Config::Object object, int relativeIndex) +{ + auto apply = [](Config &, Action &) + { + }; + + auto undo = [](Config &, Action &) + { + }; + + auto appliable = [colName, object, relativeIndex](const Config & config, const Action &) + { + auto lineIndex = config.getRelativeWordIndex(object, relativeIndex); + return !util::isEmpty(config.getAsFeature(colName, lineIndex)); }; return {Type::Check, apply, undo, appliable}; @@ -626,3 +646,97 @@ Action Action::split(int index) return {Type::Write, apply, undo, appliable}; } +Action Action::setRootUpdateIdsEmptyStackIfSentChanged() +{ + auto apply = [](Config & config, Action & a) + { + int lineIndex = config.getWordIndex(); + int rootIndex = -1; + int lastSentId = -1; + int firstIndexOfSentence = lineIndex; + + if (config.getAsFeature(Config::EOSColName, lineIndex) != Config::EOSSymbol1) + return; + + for (int i = lineIndex-1; true; --i) + { + if (!config.has(0, i, 0)) + { + if (i < 0) + break; + util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); + } + if (!config.isTokenPredicted(i)) + continue; + + if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1) + { + lastSentId = std::stoi(config.getAsFeature(Config::sentIdColName, i)); + break; + } + + if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + rootIndex = i; + + firstIndexOfSentence = i; + } + + if (firstIndexOfSentence < 0) + util::myThrow("could not find any token in current sentence"); + + for (int i = firstIndexOfSentence; i <= lineIndex; ++i) + { + if (!config.has(0, i, 0)) + { + if (i < 0) + break; + util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); + } + if (!config.isTokenPredicted(i)) + continue; + + if (util::isEmpty(config.getAsFeature(Config::headColName, i))) + { + if (i == rootIndex) + { + config.getFirstEmpty(Config::headColName, i) = "0"; + config.getFirstEmpty(Config::deprelColName, i) = "root"; + } + else + { + config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex); + } + } + } + + for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) + { + if (config.isComment(i) || config.isEmptyNode(i)) + continue; + + if (config.isMultiwordPredicted(i)) + config.getFirstEmpty(Config::idColName, i) = fmt::format("{}-{}", currentId, currentId+config.getMultiwordSizePredicted(i)); + else + config.getFirstEmpty(Config::idColName, i) = fmt::format("{}", currentId++); + + config.getFirstEmpty(Config::sentIdColName, i) = fmt::format("{}", lastSentId+1); + } + + while (config.hasStack(0)) + config.popStack(); + }; + + auto undo = [](Config & config, Action & a) + { + //TODO undo this + }; + + auto appliable = [](const Config & config, const Action &) + { + int lineIndex = config.getWordIndex(); + return config.has(0,lineIndex,0); + }; + + return {Type::Write, apply, undo, appliable}; +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index c9561569e5421f1190a7310a252f738fcb66cc4e..a159f39a6863f3cf061ff26afc823cf1fca68641 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -158,7 +158,7 @@ void Transition::initEndWord() void Transition::initAddCharToWord() { - sequence.emplace_back(Action::assertIsEmpty(Config::idColName)); + sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0)); sequence.emplace_back(Action::addLinesIfNeeded(0)); sequence.emplace_back(Action::addCurCharToCurWord()); sequence.emplace_back(Action::moveCharacterIndex(1)); @@ -185,8 +185,8 @@ void Transition::initAddCharToWord() void Transition::initSplitWord(std::vector<std::string> words) { auto consumedWord = util::splitAsUtf8(words[0]); - sequence.emplace_back(Action::assertIsEmpty(Config::idColName)); - sequence.emplace_back(Action::assertIsEmpty("FORM")); + sequence.emplace_back(Action::assertIsEmpty(Config::idColName, Config::Object::Buffer, 0)); + sequence.emplace_back(Action::assertIsEmpty("FORM", Config::Object::Buffer, 0)); sequence.emplace_back(Action::addLinesIfNeeded(words.size())); sequence.emplace_back(Action::consumeCharacterIndex(consumedWord)); for (unsigned int i = 0; i < words.size(); i++) @@ -228,6 +228,7 @@ void Transition::initSplit(int index) void Transition::initShift() { sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); cost = [](const Config & config) { @@ -303,6 +304,7 @@ void Transition::initRight(std::string label) sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0)); sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label)); sequence.emplace_back(Action::pushWordIndexOnStack()); + sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged()); cost = [label](const Config & config) { @@ -354,6 +356,7 @@ void Transition::initRight(std::string label) void Transition::initReduce() { + sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0)); sequence.emplace_back(Action::popStack()); cost = [](const Config & config)