diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index f73bba46c5dfc571df418ce2ee201bec06f544b2..aef9932fec06c2ce2a4663f4c8ed00726f362b3f 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -56,6 +56,7 @@ class Config String state{"NONE"}; boost::circular_buffer<String> history{10}; boost::circular_buffer<std::size_t> stack{50}; + std::vector<std::string> extraColumns{isMultiColName, childsColName, sentIdColName, EOSColName}; protected : @@ -145,6 +146,7 @@ class Config void addComment(); void setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions); const std::vector<Transition *> & getAppliableSplitTransitions() const; + bool isExtraColumn(const std::string & colName) const; }; #endif diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 6632c750ed2a3ab4c93af6c5bc1920c985390c88..1580c9a759e6b3293a345614d04026a023906f48 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -28,25 +28,13 @@ void BaseConfig::readMCD(std::string_view mcdFilename) std::fclose(file); - if (colName2Index.count(isMultiColName)) - util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, isMultiColName)); - colIndex2Name.emplace_back(isMultiColName); - colName2Index.emplace(isMultiColName, colIndex2Name.size()-1); - - if (colName2Index.count(childsColName)) - util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, childsColName)); - colIndex2Name.emplace_back(childsColName); - colName2Index.emplace(childsColName, colIndex2Name.size()-1); - - if (colName2Index.count(sentIdColName)) - util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, sentIdColName)); - colIndex2Name.emplace_back(sentIdColName); - colName2Index.emplace(sentIdColName, colIndex2Name.size()-1); - - if (colName2Index.count(EOSColName)) - util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName)); - colIndex2Name.emplace_back(EOSColName); - colName2Index.emplace(EOSColName, colIndex2Name.size()-1); + for (auto & column : extraColumns) + { + if (colName2Index.count(column)) + util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, column)); + colIndex2Name.emplace_back(column); + colName2Index.emplace(column, colIndex2Name.size()-1); + } } void BaseConfig::readRawInput(std::string_view rawFilename) diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 9526a870a66165134e16f912070473609cb9e0e8..1ea995cb7e4bd33a176665e02eb58621470ae488 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -98,7 +98,7 @@ void Config::print(FILE * dest) const } for (unsigned int i = 0; i < getNbColumns()-1; i++) { - if (getColName(i) == isMultiColName or getColName(i) == childsColName) + if (isExtraColumn(getColName(i)) and getColName(i) != EOSColName) { if (i == getNbColumns()-2) currentSequence.back().back() = '\n'; @@ -146,7 +146,7 @@ void Config::printForDebug(FILE * dest) const toPrint.back().emplace_back(""); for (unsigned int i = 0; i < getNbColumns(); i++) { - if (getColName(i) == isMultiColName or getColName(i) == childsColName) + if (isExtraColumn(getColName(i)) and getColName(i) != EOSColName) continue; toPrint.back().emplace_back(getColName(i)); } @@ -159,7 +159,7 @@ void Config::printForDebug(FILE * dest) const toPrint.back().emplace_back(line == (int)wordIndex ? "=>" : ""); for (unsigned int i = 0; i < getNbColumns(); i++) { - if (getColName(i) == isMultiColName or getColName(i) == childsColName) + if (isExtraColumn(getColName(i)) and getColName(i) != EOSColName) continue; std::string colContent = has(i,line,0) ? getAsFeature(i, line).get() : "?"; std::string toPrintCol = colContent; @@ -563,6 +563,10 @@ void Config::addPredicted(const std::set<std::string> & predicted) util::myThrow(fmt::format("unknown column '{}'", col)); this->predicted.insert(col); } + + for (auto & col : extraColumns) + if (col != EOSColName) + this->predicted.insert(col); } bool Config::isPredicted(const std::string & colName) const @@ -674,3 +678,11 @@ Config::Object Config::str2object(const std::string & s) return Object::Buffer; } +bool Config::isExtraColumn(const std::string & colName) const +{ + for (auto & extraCol : extraColumns) + if (extraCol == colName) + return true; + return false; +} +