diff --git a/decoder/src/Decoder.cpp b/decoder/src/Decoder.cpp index cb9937edbf78a1bb657a81c544468a8257f3c708..db0444a0d07cc329f8359011698d60417f50f8e2 100644 --- a/decoder/src/Decoder.cpp +++ b/decoder/src/Decoder.cpp @@ -88,6 +88,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool if (debug) fmt::print(stderr, "Forcing EOS transition\n"); } + + // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script + try {config.addMissingColumns();} + catch (std::exception & e) {util::myThrow(e.what());} } float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const @@ -145,7 +149,8 @@ std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std: std::vector<std::pair<float, std::string>> scores; for (auto & colName : colNames) - scores.emplace_back(std::make_pair((this->*metric2score)(getMetricOfColName(colName)), getMetricOfColName(colName))); + if (colName != Config::idColName) + scores.emplace_back(std::make_pair((this->*metric2score)(getMetricOfColName(colName)), getMetricOfColName(colName))); return scores; } @@ -160,6 +165,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const return "Sentences"; if (colName == "FEATS") return "UFeats"; + if (colName == "FORM") + return "Words"; return colName; } diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index de65870764f7a82f515fdcd4c5d9ab1bcf2fb605..13e1806f5afce69d2c3f475c79d2f3f0cced6242 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -93,10 +93,15 @@ class Config void addToStack(std::size_t index); void popStack(); bool isComment(std::size_t lineIndex) const; + bool isCommentPredicted(std::size_t lineIndex) const; bool isMultiword(std::size_t lineIndex) const; + bool isMultiwordPredicted(std::size_t lineIndex) const; int getMultiwordSize(std::size_t lineIndex) const; + int getMultiwordSizePredicted(std::size_t lineIndex) const; bool isEmptyNode(std::size_t lineIndex) const; + bool isEmptyNodePredicted(std::size_t lineIndex) const; bool isToken(std::size_t lineIndex) const; + bool isTokenPredicted(std::size_t lineIndex) const; bool moveWordIndex(int relativeMovement); bool canMoveWordIndex(int relativeMovement) const; bool moveCharacterIndex(int relativeMovement); @@ -116,6 +121,8 @@ class Config int getLastPoppedStack() const; int getCurrentWordId() const; void setCurrentWordId(int currentWordId); + void addMissingColumns(); + void addComment(); }; #endif diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index b067b91b46be1810038066e0d0736b6dc0c39d25..696d76c7277431cf644b0012394c735ec6784d8f 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -156,7 +156,10 @@ BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilenam readTSVInput(tsvFilename); if (!has(0,wordIndex,0)) + { + addComment(); addLines(1); + } if (isComment(wordIndex)) moveWordIndex(1); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 9af5d8c920079bac28b843a65fcdf8abce65035f..24612ff72521e73689f00f81fa189e485eb339dc 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -20,6 +20,13 @@ void Config::addLines(unsigned int nbLines) lines.resize(lines.size() + nbLines*getNbColumns()*(nbHypothesesMax+1)); } +void Config::addComment() +{ + lines.resize(lines.size() + getNbColumns()*(nbHypothesesMax+1)); + get(0, getNbLines()-1, 0) = "#"; + getLastNotEmptyHyp(0, getNbLines()-1) = "#"; +} + void Config::resizeLines(unsigned int nbLines) { lines.resize(nbLines*getNbColumns()*(nbHypothesesMax+1)); @@ -342,27 +349,54 @@ bool Config::isComment(std::size_t lineIndex) const return !iter->get().empty() and iter->get()[0] == '#'; } +bool Config::isCommentPredicted(std::size_t lineIndex) const +{ + auto & col0 = getAsFeature(0, lineIndex); + return !util::isEmpty(col0) and col0.get()[0] == '#'; +} + bool Config::isMultiword(std::size_t lineIndex) const { return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('-') != std::string::npos; } +bool Config::isMultiwordPredicted(std::size_t lineIndex) const +{ + return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('-') != std::string::npos; +} + int Config::getMultiwordSize(std::size_t lineIndex) const { auto splited = util::split(getConst(idColName, lineIndex, 0).get(), '-'); return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0])); } +int Config::getMultiwordSizePredicted(std::size_t lineIndex) const +{ + auto splited = util::split(getAsFeature(idColName, lineIndex).get(), '-'); + return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0])); +} + bool Config::isEmptyNode(std::size_t lineIndex) const { return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('.') != std::string::npos; } +bool Config::isEmptyNodePredicted(std::size_t lineIndex) const +{ + return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('.') != std::string::npos; +} + bool Config::isToken(std::size_t lineIndex) const { return !isComment(lineIndex) && !isMultiword(lineIndex) && !isEmptyNode(lineIndex); } +bool Config::isTokenPredicted(std::size_t lineIndex) const +{ + return !isCommentPredicted(lineIndex) && !isMultiwordPredicted(lineIndex) && !isEmptyNodePredicted(lineIndex); +} + bool Config::moveWordIndex(int relativeMovement) { int nbMovements = 0; @@ -504,3 +538,28 @@ void Config::setCurrentWordId(int currentWordId) this->currentWordId = currentWordId; } +void Config::addMissingColumns() +{ + int firstIndex = 0; + for (unsigned int index = 0; index < getNbLines(); index++) + { + if (!isTokenPredicted(index)) + continue; + + if (util::isEmpty(getAsFeature(idColName, index))) + { + int last = 0; + if (index > 0 and isTokenPredicted(index-1)) + last = std::stoi(getAsFeature(idColName, index-1)); + getLastNotEmptyHyp(idColName, index) = std::to_string(last+1); + } + + int curId = std::stoi(getAsFeature(idColName, index)); + if (curId == 1) + firstIndex = index; + + if (util::isEmpty(getAsFeature(headColName, index))) + getLastNotEmptyHyp(headColName, index) = (curId == 1) ? "0" : std::to_string(firstIndex); + } +} +