From fd4bc1579ac58868289439015b0c6b117b45bd2c Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 15 Jul 2020 23:36:22 +0200 Subject: [PATCH] Action EOS adds text metadata --- reading_machine/include/Action.hpp | 1 + reading_machine/include/Config.hpp | 3 ++ reading_machine/src/Action.cpp | 55 ++++++++++++++++++++++++++++-- reading_machine/src/BaseConfig.cpp | 2 ++ reading_machine/src/Config.cpp | 25 ++++++++++---- reading_machine/src/Transition.cpp | 1 + 6 files changed, 77 insertions(+), 10 deletions(-) diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index 6a34c90..71cf7d5 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -39,6 +39,7 @@ class Action public : static Action addLinesIfNeeded(int nbLines); + static Action addMetadataLinesIfNeeded(); static Action moveWordIndex(int movement); static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 71fb1ee..ecea742 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -50,6 +50,7 @@ class Config Utf8String rawInput; std::size_t wordIndex{0}; std::size_t characterIndex{0}; + std::size_t currentSentenceStartRawInput{0}; String state{"NONE"}; boost::circular_buffer<String> history{10}; boost::circular_buffer<std::size_t> stack{50}; @@ -164,6 +165,8 @@ class Config bool isExtraColumn(const std::string & colName) const; void setStrategy(const std::vector<std::string> & strategyDefinition); Strategy & getStrategy(); + std::size_t getCurrentSentenceStartRawInput() const; + void setCurrentSentenceStartRawInput(std::size_t value); }; #endif diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index a996ce8..e04b5f6 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -29,6 +29,39 @@ Action Action::addLinesIfNeeded(int nbLines) return {Type::AddLines, apply, undo, appliable}; } +Action Action::addMetadataLinesIfNeeded() +{ + auto apply = [](Config & config, Action &) + { + if (!config.hasCharacter(0)) + return; + if (config.rawInputOnlySeparatorsLeft()) + return; + + if (!config.has(0, config.getWordIndex()+1, 0)) + config.addLines(1); + if (!config.has(0, config.getWordIndex()+2, 0)) + config.addLines(1); + if (!config.has(0, config.getWordIndex()+3, 0)) + config.addLines(1); + + config.getLastNotEmptyHyp(0, config.getWordIndex()+1) = "#"; + config.getLastNotEmptyHyp(0, config.getWordIndex()+2) = "#"; + }; + + auto undo = [](Config &, Action &) + { + //TODO undo this + }; + + auto appliable = [](const Config &, const Action &) + { + return true; + }; + + return {Type::AddLines, apply, undo, appliable}; +} + Action Action::moveWordIndex(int movement) { auto apply = [movement](Config & config, Action &) @@ -588,7 +621,7 @@ Action Action::updateIds(int bufferIndex) break; util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); } - if (config.isComment(i) || config.isEmptyNode(i)) + if (config.isCommentPredicted(i) || config.isEmptyNode(i)) continue; if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) @@ -605,7 +638,7 @@ Action Action::updateIds(int bufferIndex) for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) { - if (config.isComment(i) || config.isEmptyNode(i)) + if (config.isCommentPredicted(i) || config.isEmptyNode(i)) continue; if (config.isMultiwordPredicted(i)) @@ -615,6 +648,22 @@ Action Action::updateIds(int bufferIndex) config.getFirstEmpty(Config::sentIdColName, i) = fmt::format("{}", lastSentId+1); } + + // Update metadata '# text = ...' and '# sent_id = X' before the sentence + if (config.hasCharacter(0)) + { + if (config.has(0,firstIndexOfSentence-1,0) and config.isCommentPredicted(firstIndexOfSentence-1)) + { + std::string textMetadata = "# text = "; + for (auto i = config.getCurrentSentenceStartRawInput(); i < config.getCharacterIndex(); i++) + textMetadata = fmt::format("{}{}", textMetadata, config.getLetter(i)); + config.getLastNotEmptyHyp(0, firstIndexOfSentence-1) = textMetadata; + } + if (config.has(0,firstIndexOfSentence-2,0) and config.isCommentPredicted(firstIndexOfSentence-2)) + config.getLastNotEmptyHyp(0, firstIndexOfSentence-2) = fmt::format("# sent_id = {}", config.getAsFeature(Config::sentIdColName, firstIndexOfSentence)); + + config.setCurrentSentenceStartRawInput(config.getCharacterIndex()); + } }; auto undo = [](Config & config, Action & a) @@ -774,7 +823,7 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) { - if (config.isComment(i) || config.isEmptyNode(i)) + if (config.isCommentPredicted(i) || config.isEmptyNode(i)) continue; if (config.isMultiwordPredicted(i)) diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index d365cf6..a3223e4 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -116,6 +116,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; get(0, getNbLines()-1, 0) = std::string(line); + getLastNotEmptyHyp(0, getNbLines()-1) = std::string(line); continue; } @@ -174,6 +175,7 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::strin if (!has(0,wordIndex,0)) { + addComment(); addComment(); addLines(1); } diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index c967a4f..5d095dc 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -111,9 +111,9 @@ void Config::print(FILE * dest) const for (unsigned int line = 0; line < getNbLines(); line++) { - if (isComment(getFirstLineIndex()+line)) + if (isCommentPredicted(getFirstLineIndex()+line)) { - currentSequenceComments.emplace_back(fmt::format("{}\n", getConst(0, getFirstLineIndex()+line, 0))); + currentSequenceComments.emplace_back(fmt::format("{}\n", getLastNotEmptyHypConst(0, getFirstLineIndex()+line))); continue; } for (unsigned int i = 0; i < getNbColumns()-1; i++) @@ -171,7 +171,7 @@ void Config::printForDebug(FILE * dest) const for (int line = firstLineToPrint; line <= lastLineToPrint; line++) { - if (isComment(line)) + if (isCommentPredicted(line)) continue; toPrint.emplace_back(); toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : ""); @@ -451,6 +451,7 @@ bool Config::moveWordIndex(int relativeMovement) { int nbMovements = 0; int oldVal = wordIndex; + while (nbMovements != relativeMovement) { do @@ -462,7 +463,7 @@ bool Config::moveWordIndex(int relativeMovement) return false; } } - while (isComment(wordIndex)); + while (isCommentPredicted(wordIndex)); nbMovements += relativeMovement > 0 ? 1 : -1; } @@ -481,11 +482,11 @@ void Config::moveWordIndexRelaxed(int relativeMovement) break; wordIndex += increment; } - while (isComment(wordIndex)); + while (isCommentPredicted(wordIndex)); nbMovements += relativeMovement > 0 ? 1 : -1; } - if (!isComment(wordIndex)) + if (!isCommentPredicted(wordIndex)) return; moveWordIndex(-increment); @@ -503,7 +504,7 @@ bool Config::canMoveWordIndex(int relativeMovement) const if (!has(0,oldVal,0)) return false; } - while (isComment(oldVal)); + while (isCommentPredicted(oldVal)); nbMovements += relativeMovement > 0 ? 1 : -1; } @@ -784,3 +785,13 @@ Strategy & Config::getStrategy() return *strategy.get(); } +std::size_t Config::getCurrentSentenceStartRawInput() const +{ + return currentSentenceStartRawInput; +} + +void Config::setCurrentSentenceStartRawInput(std::size_t value) +{ + currentSentenceStartRawInput = value; +} + diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index c8d3804..2b97ca0 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -700,6 +700,7 @@ void Transition::initReduce_relaxed() void Transition::initEOS(int bufferIndex) { + sequence.emplace_back(Action::addMetadataLinesIfNeeded()); sequence.emplace_back(Action::setRoot(bufferIndex)); sequence.emplace_back(Action::updateIds(bufferIndex)); sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); -- GitLab