Skip to content
Snippets Groups Projects
Commit ffdff2ff authored by Franck Dary's avatar Franck Dary
Browse files

Added function Config::addMissingColumns that correct the current config to...

Added function Config::addMissingColumns that correct the current config to make it compatible with conll eval script
parent 7c5c406e
No related branches found
No related tags found
No related merge requests found
...@@ -88,6 +88,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool ...@@ -88,6 +88,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
if (debug) if (debug)
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
} }
// Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
try {config.addMissingColumns();}
catch (std::exception & e) {util::myThrow(e.what());}
} }
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
...@@ -145,6 +149,7 @@ std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std: ...@@ -145,6 +149,7 @@ std::vector<std::pair<float,std::string>> Decoder::getScores(const std::set<std:
std::vector<std::pair<float, std::string>> scores; std::vector<std::pair<float, std::string>> scores;
for (auto & colName : colNames) for (auto & colName : colNames)
if (colName != Config::idColName)
scores.emplace_back(std::make_pair((this->*metric2score)(getMetricOfColName(colName)), getMetricOfColName(colName))); scores.emplace_back(std::make_pair((this->*metric2score)(getMetricOfColName(colName)), getMetricOfColName(colName)));
return scores; return scores;
...@@ -160,6 +165,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const ...@@ -160,6 +165,8 @@ std::string Decoder::getMetricOfColName(const std::string & colName) const
return "Sentences"; return "Sentences";
if (colName == "FEATS") if (colName == "FEATS")
return "UFeats"; return "UFeats";
if (colName == "FORM")
return "Words";
return colName; return colName;
} }
......
...@@ -93,10 +93,15 @@ class Config ...@@ -93,10 +93,15 @@ class Config
void addToStack(std::size_t index); void addToStack(std::size_t index);
void popStack(); void popStack();
bool isComment(std::size_t lineIndex) const; bool isComment(std::size_t lineIndex) const;
bool isCommentPredicted(std::size_t lineIndex) const;
bool isMultiword(std::size_t lineIndex) const; bool isMultiword(std::size_t lineIndex) const;
bool isMultiwordPredicted(std::size_t lineIndex) const;
int getMultiwordSize(std::size_t lineIndex) const; int getMultiwordSize(std::size_t lineIndex) const;
int getMultiwordSizePredicted(std::size_t lineIndex) const;
bool isEmptyNode(std::size_t lineIndex) const; bool isEmptyNode(std::size_t lineIndex) const;
bool isEmptyNodePredicted(std::size_t lineIndex) const;
bool isToken(std::size_t lineIndex) const; bool isToken(std::size_t lineIndex) const;
bool isTokenPredicted(std::size_t lineIndex) const;
bool moveWordIndex(int relativeMovement); bool moveWordIndex(int relativeMovement);
bool canMoveWordIndex(int relativeMovement) const; bool canMoveWordIndex(int relativeMovement) const;
bool moveCharacterIndex(int relativeMovement); bool moveCharacterIndex(int relativeMovement);
...@@ -116,6 +121,8 @@ class Config ...@@ -116,6 +121,8 @@ class Config
int getLastPoppedStack() const; int getLastPoppedStack() const;
int getCurrentWordId() const; int getCurrentWordId() const;
void setCurrentWordId(int currentWordId); void setCurrentWordId(int currentWordId);
void addMissingColumns();
void addComment();
}; };
#endif #endif
...@@ -156,7 +156,10 @@ BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilenam ...@@ -156,7 +156,10 @@ BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilenam
readTSVInput(tsvFilename); readTSVInput(tsvFilename);
if (!has(0,wordIndex,0)) if (!has(0,wordIndex,0))
{
addComment();
addLines(1); addLines(1);
}
if (isComment(wordIndex)) if (isComment(wordIndex))
moveWordIndex(1); moveWordIndex(1);
......
...@@ -20,6 +20,13 @@ void Config::addLines(unsigned int nbLines) ...@@ -20,6 +20,13 @@ void Config::addLines(unsigned int nbLines)
lines.resize(lines.size() + nbLines*getNbColumns()*(nbHypothesesMax+1)); lines.resize(lines.size() + nbLines*getNbColumns()*(nbHypothesesMax+1));
} }
void Config::addComment()
{
lines.resize(lines.size() + getNbColumns()*(nbHypothesesMax+1));
get(0, getNbLines()-1, 0) = "#";
getLastNotEmptyHyp(0, getNbLines()-1) = "#";
}
void Config::resizeLines(unsigned int nbLines) void Config::resizeLines(unsigned int nbLines)
{ {
lines.resize(nbLines*getNbColumns()*(nbHypothesesMax+1)); lines.resize(nbLines*getNbColumns()*(nbHypothesesMax+1));
...@@ -342,27 +349,54 @@ bool Config::isComment(std::size_t lineIndex) const ...@@ -342,27 +349,54 @@ bool Config::isComment(std::size_t lineIndex) const
return !iter->get().empty() and iter->get()[0] == '#'; return !iter->get().empty() and iter->get()[0] == '#';
} }
bool Config::isCommentPredicted(std::size_t lineIndex) const
{
auto & col0 = getAsFeature(0, lineIndex);
return !util::isEmpty(col0) and col0.get()[0] == '#';
}
bool Config::isMultiword(std::size_t lineIndex) const bool Config::isMultiword(std::size_t lineIndex) const
{ {
return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('-') != std::string::npos; return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('-') != std::string::npos;
} }
bool Config::isMultiwordPredicted(std::size_t lineIndex) const
{
return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('-') != std::string::npos;
}
int Config::getMultiwordSize(std::size_t lineIndex) const int Config::getMultiwordSize(std::size_t lineIndex) const
{ {
auto splited = util::split(getConst(idColName, lineIndex, 0).get(), '-'); auto splited = util::split(getConst(idColName, lineIndex, 0).get(), '-');
return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0])); return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0]));
} }
int Config::getMultiwordSizePredicted(std::size_t lineIndex) const
{
auto splited = util::split(getAsFeature(idColName, lineIndex).get(), '-');
return std::stoi(std::string(splited[1])) - std::stoi(std::string(splited[0]));
}
bool Config::isEmptyNode(std::size_t lineIndex) const bool Config::isEmptyNode(std::size_t lineIndex) const
{ {
return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('.') != std::string::npos; return hasColIndex(idColName) && getConst(idColName, lineIndex, 0).get().find('.') != std::string::npos;
} }
bool Config::isEmptyNodePredicted(std::size_t lineIndex) const
{
return hasColIndex(idColName) && getAsFeature(idColName, lineIndex).get().find('.') != std::string::npos;
}
bool Config::isToken(std::size_t lineIndex) const bool Config::isToken(std::size_t lineIndex) const
{ {
return !isComment(lineIndex) && !isMultiword(lineIndex) && !isEmptyNode(lineIndex); return !isComment(lineIndex) && !isMultiword(lineIndex) && !isEmptyNode(lineIndex);
} }
bool Config::isTokenPredicted(std::size_t lineIndex) const
{
return !isCommentPredicted(lineIndex) && !isMultiwordPredicted(lineIndex) && !isEmptyNodePredicted(lineIndex);
}
bool Config::moveWordIndex(int relativeMovement) bool Config::moveWordIndex(int relativeMovement)
{ {
int nbMovements = 0; int nbMovements = 0;
...@@ -504,3 +538,28 @@ void Config::setCurrentWordId(int currentWordId) ...@@ -504,3 +538,28 @@ void Config::setCurrentWordId(int currentWordId)
this->currentWordId = currentWordId; this->currentWordId = currentWordId;
} }
void Config::addMissingColumns()
{
int firstIndex = 0;
for (unsigned int index = 0; index < getNbLines(); index++)
{
if (!isTokenPredicted(index))
continue;
if (util::isEmpty(getAsFeature(idColName, index)))
{
int last = 0;
if (index > 0 and isTokenPredicted(index-1))
last = std::stoi(getAsFeature(idColName, index-1));
getLastNotEmptyHyp(idColName, index) = std::to_string(last+1);
}
int curId = std::stoi(getAsFeature(idColName, index));
if (curId == 1)
firstIndex = index;
if (util::isEmpty(getAsFeature(headColName, index)))
getLastNotEmptyHyp(headColName, index) = (curId == 1) ? "0" : std::to_string(firstIndex);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment