diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp index dbf344b1e3c24acacc60ef0b2a5da3368f49e1a3..4d88e48118f90fe3edadd920b98e46c1be1a3a60 100644 --- a/reading_machine/include/Action.hpp +++ b/reading_machine/include/Action.hpp @@ -39,7 +39,6 @@ class Action public : static Action addLinesIfNeeded(int nbLines); - static Action addMetadataLinesIfNeeded(); static Action moveWordIndex(int movement); static Action moveCharacterIndex(int movement); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 7e660d399080fc4cee2f0cc8410705ec4dc95b9f..5a5e72e24efe8cb7cf90a5de0a16394f93ebabcd 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -25,6 +25,9 @@ class Config static constexpr const char * sentIdColName = "SENTID"; static constexpr const char * isMultiColName = "MULTI"; static constexpr const char * childsColName = "CHILDS"; + static constexpr const char * commentsColName = "COMMENTS"; + static constexpr const char * rawRangeStartColName = "RAWSTART"; + static constexpr const char * rawRangeEndColName = "RAWEND"; static constexpr int nbHypothesesMax = 1; static constexpr int maxNbAppliableSplitTransitions = 8; @@ -55,7 +58,7 @@ class Config boost::circular_buffer<String> history{10}; boost::circular_buffer<std::size_t> stack{50}; float chosenActionScore{0.0}; - std::vector<std::string> extraColumns{isMultiColName, childsColName, sentIdColName, EOSColName}; + std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName}; std::set<std::string> predicted; int lastPoppedStack{-1}; int lastAttached{-1}; @@ -122,8 +125,6 @@ class Config void addToStack(std::size_t index); void popStack(); void swapStack(int relIndex1, int relIndex2); - bool isComment(std::size_t lineIndex) const; - bool isCommentPredicted(std::size_t lineIndex) const; bool isMultiword(std::size_t lineIndex) const; bool isMultiwordPredicted(std::size_t lineIndex) const; int getMultiwordSize(std::size_t lineIndex) const; @@ -158,7 +159,6 @@ class Config int getCurrentWordId() const; void setCurrentWordId(int currentWordId); void addMissingColumns(); - void addComment(); void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions); void setAppliableTransitions(const std::vector<int> & appliableTransitions); const std::vector<Transition *> & getAppliableSplitTransitions() const; @@ -166,8 +166,6 @@ class Config bool isExtraColumn(const std::string & colName) const; void setStrategy(const std::vector<std::string> & strategyDefinition); Strategy & getStrategy(); - std::size_t getCurrentSentenceStartRawInput() const; - void setCurrentSentenceStartRawInput(std::size_t value); void setChosenActionScore(float chosenActionScore); float getChosenActionScore() const; }; diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp index a6acb5228b0fc018cf4d6489e1ea49d63898ecde..40033bebcfd1206e5fc8e22cb301a6dcb36f400f 100644 --- a/reading_machine/src/Action.cpp +++ b/reading_machine/src/Action.cpp @@ -29,39 +29,6 @@ Action Action::addLinesIfNeeded(int nbLines) return {Type::AddLines, apply, undo, appliable}; } -Action Action::addMetadataLinesIfNeeded() -{ - auto apply = [](Config & config, Action &) - { - if (!config.hasCharacter(0)) - return; - if (config.rawInputOnlySeparatorsLeft()) - return; - - if (!config.has(0, config.getWordIndex()+1, 0)) - config.addLines(1); - if (!config.has(0, config.getWordIndex()+2, 0)) - config.addLines(1); - if (!config.has(0, config.getWordIndex()+3, 0)) - config.addLines(1); - - config.getLastNotEmptyHyp(0, config.getWordIndex()+1) = "#"; - config.getLastNotEmptyHyp(0, config.getWordIndex()+2) = "#"; - }; - - auto undo = [](Config &, Action &) - { - //TODO undo this - }; - - auto appliable = [](const Config &, const Action &) - { - return true; - }; - - return {Type::AddLines, apply, undo, appliable}; -} - Action Action::moveWordIndex(int movement) { auto apply = [movement](Config & config, Action &) @@ -496,18 +463,33 @@ Action Action::addCharsToCol(const std::string & col, int n, Config::Object obje { auto apply = [col, n, object, relativeIndex](Config & config, Action & a) { - auto & curWord = config.getLastNotEmptyHyp(col, config.getRelativeWordIndex(object, relativeIndex)); + auto index = config.getRelativeWordIndex(object, relativeIndex); + auto & curWord = config.getLastNotEmptyHyp(col, index); + if (col == "FORM") + { + if (util::isEmpty(config.getAsFeature(Config::rawRangeStartColName, index))) + config.getLastNotEmptyHyp(Config::rawRangeStartColName, index) = fmt::format("{}", config.getCharacterIndex()); + if (util::isEmpty(config.getAsFeature(Config::rawRangeEndColName, index))) + config.getLastNotEmptyHyp(Config::rawRangeEndColName, index) = fmt::format("{}", config.getCharacterIndex()); + int curEndValue = std::stoi(config.getAsFeature(Config::rawRangeEndColName, index)); + config.getLastNotEmptyHyp(Config::rawRangeEndColName, index) = fmt::format("{}", curEndValue+n); + } for (int i = 0; i < n; i++) curWord = fmt::format("{}{}", curWord, config.getLetter(config.getCharacterIndex()+i)); }; auto undo = [col, n, object, relativeIndex](Config & config, Action & a) { - auto & curWord = config.getLastNotEmptyHyp(col, config.getRelativeWordIndex(object, relativeIndex)); + auto index = config.getRelativeWordIndex(object, relativeIndex); + auto & curWord = config.getLastNotEmptyHyp(col, index); auto newWord = util::splitAsUtf8(curWord.get()); for (int i = 0; i < n; i++) newWord.pop_back(); curWord = fmt::format("{}", newWord); + if (newWord.size() == 0) + config.getLastNotEmptyHyp(Config::rawRangeStartColName, index) = "0"; + int curEndValue = std::stoi(config.getAsFeature(Config::rawRangeEndColName, index)); + config.getLastNotEmptyHyp(Config::rawRangeEndColName, index) = fmt::format("{}", curEndValue-n); }; auto appliable = [col, n, object, relativeIndex](const Config & config, const Action &) @@ -576,7 +558,7 @@ Action Action::setRoot(int bufferIndex) { if (i == rootIndex) { - config.getFirstEmpty(Config::headColName, i) = "0"; + config.getFirstEmpty(Config::headColName, i) = "-1"; config.getFirstEmpty(Config::deprelColName, i) = "root"; } else @@ -621,7 +603,7 @@ Action Action::updateIds(int bufferIndex) break; util::myThrow("The current sentence is too long to be completly held by the data strucure. Consider increasing SubConfig::SpanSize"); } - if (config.isCommentPredicted(i) || config.isEmptyNode(i)) + if (config.isEmptyNode(i)) continue; if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1) @@ -638,7 +620,7 @@ Action Action::updateIds(int bufferIndex) for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) { - if (config.isCommentPredicted(i) || config.isEmptyNode(i)) + if (config.isEmptyNode(i)) continue; if (config.isMultiwordPredicted(i)) @@ -651,19 +633,13 @@ Action Action::updateIds(int bufferIndex) // Update metadata '# text = ...' and '# sent_id = X' before the sentence if (config.hasCharacter(0)) - { - if (config.has(0,firstIndexOfSentence-1,0) and config.isCommentPredicted(firstIndexOfSentence-1)) + if (config.has(0,firstIndexOfSentence,0)) { std::string textMetadata = "# text = "; - for (auto i = config.getCurrentSentenceStartRawInput(); i < config.getCharacterIndex(); i++) + for (auto i = std::stoi(config.getAsFeature(Config::rawRangeStartColName, firstIndexOfSentence)); i < std::stoi(config.getAsFeature(Config::rawRangeEndColName, lineIndex)); i++) textMetadata = fmt::format("{}{}", textMetadata, config.getLetter(i)); - config.getLastNotEmptyHyp(0, firstIndexOfSentence-1) = textMetadata; + config.getLastNotEmptyHyp(Config::commentsColName, firstIndexOfSentence) = fmt::format("{}\n# sent_id = {}", textMetadata, config.getAsFeature(Config::sentIdColName, firstIndexOfSentence)); } - if (config.has(0,firstIndexOfSentence-2,0) and config.isCommentPredicted(firstIndexOfSentence-2)) - config.getLastNotEmptyHyp(0, firstIndexOfSentence-2) = fmt::format("# sent_id = {}", config.getAsFeature(Config::sentIdColName, firstIndexOfSentence)); - - config.setCurrentSentenceStartRawInput(config.getCharacterIndex()); - } }; auto undo = [](Config & config, Action & a) @@ -823,7 +799,7 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged() for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i) { - if (config.isCommentPredicted(i) || config.isEmptyNode(i)) + if (config.isEmptyNode(i)) continue; if (config.isMultiwordPredicted(i)) diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index a061c76909f9dd100c486e20d27d6eda9d5ae689..abb229e42701acb7594b7dc776d227dcdbdeb900 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -55,6 +55,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) bool inputHasBeenRead = false; int usualNbCol = -1; int nbMultiwords = 0; + std::vector<std::string> pendingComments; while (!std::feof(file)) { @@ -93,8 +94,9 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) continue; auto & head = get(headColName, i, 0); if (head == "0") - continue; - head = std::to_string(id2index[head]); + head = "-1"; + else + head = std::to_string(id2index[head]); } } catch(std::exception & e) {util::myThrow(e.what());} @@ -112,11 +114,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) })) continue; - addLines(1); - get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; - get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; - get(0, getNbLines()-1, 0) = std::string(line); - getLastNotEmptyHyp(0, getNbLines()-1) = std::string(line); + pendingComments.emplace_back(line); continue; } @@ -142,6 +140,9 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) else get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; + get(commentsColName, getNbLines()-1, 0) = util::join("\n", pendingComments); + pendingComments.clear(); + for (unsigned int i = 0; i < splited.size(); i++) if (i < colIndex2Name.size() - extraColumns.size()) { @@ -174,14 +175,7 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::strin readTSVInput(tsvFilename); if (!has(0,wordIndex,0)) - { - addComment(); - addComment(); addLines(1); - } - - if (isComment(wordIndex)) - moveWordIndex(1); } std::size_t BaseConfig::getNbColumns() const diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 7ad813bec25275b9defe30d7bb4ffbaca0674977..5b37d2c1edba4275e127bda9ce4157b2d66c3f39 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -38,13 +38,6 @@ void Config::addLines(unsigned int nbLines) lines.resize(lines.size() + nbLines*getNbColumns()*(nbHypothesesMax+1)); } -void Config::addComment() -{ - lines.resize(lines.size() + getNbColumns()*(nbHypothesesMax+1)); - get(0, getNbLines()-1, 0) = "#"; - getLastNotEmptyHyp(0, getNbLines()-1) = "#"; -} - void Config::resizeLines(unsigned int nbLines) { lines.resize(nbLines*getNbColumns()*(nbHypothesesMax+1)); @@ -111,11 +104,9 @@ void Config::print(FILE * dest) const for (unsigned int line = 0; line < getNbLines(); line++) { - if (isCommentPredicted(getFirstLineIndex()+line)) - { - currentSequenceComments.emplace_back(fmt::format("{}\n", getLastNotEmptyHypConst(0, getFirstLineIndex()+line))); - continue; - } + if (!util::isEmpty(getAsFeature(commentsColName, getFirstLineIndex()+line))) + currentSequenceComments.emplace_back(fmt::format("{}\n", getAsFeature(commentsColName, getFirstLineIndex()+line))); + for (unsigned int i = 0; i < getNbColumns()-1; i++) { if (isExtraColumn(getColName(i)) and getColName(i) != EOSColName) @@ -129,8 +120,12 @@ void Config::print(FILE * dest) const try { if (getColName(i) == headColName) - if (valueToPrint != "0") + { + if (valueToPrint != "-1") valueToPrint = getAsFeature(idColName, std::stoi(valueToPrint)); + else + valueToPrint = "0"; + } } catch(std::exception &) {} if (valueToPrint.empty()) valueToPrint = "_"; @@ -171,8 +166,6 @@ void Config::printForDebug(FILE * dest) const for (int line = firstLineToPrint; line <= lastLineToPrint; line++) { - if (isCommentPredicted(line)) - continue; toPrint.emplace_back(); toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : ""); for (unsigned int i = 0; i < getNbColumns(); i++) @@ -184,8 +177,12 @@ void Config::printForDebug(FILE * dest) const try { if (getColName(i) == headColName && toPrintCol != "_" && !toPrintCol.empty()) - if (toPrintCol != "0" && toPrintCol != "?") + { + if (toPrintCol != "-1" && toPrintCol != "?") toPrintCol = has(0,std::stoi(toPrintCol),0) ? getAsFeature(idColName, std::stoi(toPrintCol)).get() : "?"; + else if (toPrintCol == "-1") + toPrintCol = "0"; + } } catch(std::exception & e) {util::myThrow(fmt::format("toPrintCol='{}' {}", toPrintCol, e.what()));} toPrint.back().emplace_back(util::shrink(toPrintCol, maxWordLength)); } @@ -392,19 +389,6 @@ util::utf8char Config::getLetter(int letterIndex) const return rawInput[letterIndex]; } -bool Config::isComment(std::size_t lineIndex) const -{ - auto iter = getConstIterator(0, lineIndex, 0); - return !iter->get().empty() and iter->get()[0] == '#'; -} - -bool Config::isCommentPredicted(std::size_t lineIndex) const -{ - auto & col0Pred = getAsFeature(0, lineIndex); - auto & col0Gold = getConst(0, lineIndex, 0); - return (!util::isEmpty(col0Pred) and col0Pred.get()[0] == '#') or (!util::isEmpty(col0Gold) and col0Gold.get()[0] == '#'); -} - bool Config::isMultiword(std::size_t lineIndex) const { return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('-') != std::string::npos; @@ -439,76 +423,32 @@ bool Config::isEmptyNodePredicted(std::size_t lineIndex) const bool Config::isToken(std::size_t lineIndex) const { - return !isComment(lineIndex) && !isMultiword(lineIndex) && !isEmptyNode(lineIndex); + return !isMultiword(lineIndex) && !isEmptyNode(lineIndex); } bool Config::isTokenPredicted(std::size_t lineIndex) const { - return !isCommentPredicted(lineIndex) && !isMultiwordPredicted(lineIndex) && !isEmptyNodePredicted(lineIndex); + return !isMultiwordPredicted(lineIndex) && !isEmptyNodePredicted(lineIndex); } bool Config::moveWordIndex(int relativeMovement) { - int nbMovements = 0; - int oldVal = wordIndex; + if (!canMoveWordIndex(relativeMovement)) + return false; - while (nbMovements != relativeMovement) - { - do - { - relativeMovement > 0 ? wordIndex++ : wordIndex--; - if (!has(0,wordIndex,0)) - { - wordIndex = oldVal; - return false; - } - } - while (isCommentPredicted(wordIndex)); - nbMovements += relativeMovement > 0 ? 1 : -1; - } + wordIndex += relativeMovement; return true; } void Config::moveWordIndexRelaxed(int relativeMovement) { - int nbMovements = 0; - int increment = relativeMovement > 0 ? 1 : -1; - while (nbMovements != relativeMovement) - { - do - { - if (!has(0,wordIndex+increment,0)) - break; - wordIndex += increment; - } - while (isCommentPredicted(wordIndex)); - nbMovements += relativeMovement > 0 ? 1 : -1; - } - - if (!isCommentPredicted(wordIndex)) - return; - - moveWordIndex(-increment); + while (!moveWordIndex(relativeMovement)); } bool Config::canMoveWordIndex(int relativeMovement) const { - int nbMovements = 0; - int oldVal = wordIndex; - while (nbMovements != relativeMovement) - { - do - { - relativeMovement > 0 ? oldVal++ : oldVal--; - if (!has(0,oldVal,0)) - return false; - } - while (isCommentPredicted(oldVal)); - nbMovements += relativeMovement > 0 ? 1 : -1; - } - - return true; + return has(0,wordIndex+relativeMovement,0); } bool Config::moveCharacterIndex(int relativeMovement) @@ -601,7 +541,11 @@ void Config::addPredicted(const std::set<std::string> & predicted) for (auto & col : extraColumns) if (col != EOSColName) + { + if (col == commentsColName and !isPredicted(idColName)) + continue; this->predicted.insert(col); + } } bool Config::isPredicted(const std::string & colName) const @@ -646,60 +590,21 @@ void Config::addMissingColumns() if (hasColIndex(headColName)) if (util::isEmpty(getAsFeature(headColName, index))) - getLastNotEmptyHyp(headColName, index) = (curId == 1) ? "0" : std::to_string(firstIndex); + getLastNotEmptyHyp(headColName, index) = (curId == 1) ? "-1" : std::to_string(firstIndex); } } long Config::getRelativeWordIndex(int relativeIndex) const { - if (relativeIndex < 0) - { - for (int index = getWordIndex()-1, counter = 0; has(0,index,0); --index) - if (!isCommentPredicted(index)) - { - --counter; - if (counter == relativeIndex) - return index; - } - } - else - { - for (int index = getWordIndex(), counter = 0; has(0,index,0); ++index) - if (!isCommentPredicted(index)) - { - if (counter == relativeIndex) - return index; - ++counter; - } - } + if (has(0,wordIndex+relativeIndex,0)) + return wordIndex+relativeIndex; return -1; } long Config::getRelativeDistance(int fromIndex, int toIndex) const { - if (toIndex < fromIndex) - { - for (int index = fromIndex, counter = 0; has(0,index,0); --index) - if (!isCommentPredicted(index)) - { - if (index == toIndex) - return counter; - --counter; - } - } - else - { - for (int index = fromIndex, counter = 0; has(0,index,0); ++index) - if (!isCommentPredicted(index)) - { - if (index == toIndex) - return counter; - ++counter; - } - } - - return 0; + return fromIndex - toIndex; } long Config::getRelativeWordIndex(Object object, int relativeIndex) const @@ -785,16 +690,6 @@ Strategy & Config::getStrategy() return *strategy.get(); } -std::size_t Config::getCurrentSentenceStartRawInput() const -{ - return currentSentenceStartRawInput; -} - -void Config::setCurrentSentenceStartRawInput(std::size_t value) -{ - currentSentenceStartRawInput = value; -} - void Config::setChosenActionScore(float chosenActionScore) { this->chosenActionScore = chosenActionScore; diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp index 23097ff5e9bbe36d8adc11396d58a6979781d994..29dcb6cba418d988ca4a9bc1c4a7776995c83131 100644 --- a/reading_machine/src/Transition.cpp +++ b/reading_machine/src/Transition.cpp @@ -717,7 +717,6 @@ void Transition::initReduce_relaxed() void Transition::initEOS(int bufferIndex) { - sequence.emplace_back(Action::addMetadataLinesIfNeeded()); sequence.emplace_back(Action::setRoot(bufferIndex)); sequence.emplace_back(Action::updateIds(bufferIndex)); sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1)); diff --git a/torch_modules/src/ContextModule.cpp b/torch_modules/src/ContextModule.cpp index 6629b18eb876b400af5024d03334640e03951b08..05e6823e38321c622498c5cfeaeff18e8c8a6103 100644 --- a/torch_modules/src/ContextModule.cpp +++ b/torch_modules/src/ContextModule.cpp @@ -134,9 +134,7 @@ void ContextModuleImpl::addToContext(std::vector<std::vector<long>> & context, c if (col == Config::idColName) { std::string value; - if (config.isCommentPredicted(index)) - value = "comment"; - else if (config.isMultiwordPredicted(index)) + if (config.isMultiwordPredicted(index)) value = "multiword"; else if (config.isTokenPredicted(index)) value = "token"; diff --git a/torch_modules/src/ContextualModule.cpp b/torch_modules/src/ContextualModule.cpp index 4829596734cd484ca4cc0552fffbb693eb567770..11537be41d3055f6a74423412cdc8b836a6d0e1a 100644 --- a/torch_modules/src/ContextualModule.cpp +++ b/torch_modules/src/ContextualModule.cpp @@ -151,9 +151,7 @@ void ContextualModuleImpl::addToContext(std::vector<std::vector<long>> & context if (col == Config::idColName) { std::string value; - if (config.isCommentPredicted(index)) - value = "comment"; - else if (config.isMultiwordPredicted(index)) + if (config.isMultiwordPredicted(index)) value = "multiword"; else if (config.isTokenPredicted(index)) value = "token";