diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 7d62241cb06f84f57c4f8df2bf5551e484024b00..a3ddb73743fb59a97ebf12890d99fe5c6da2e14a 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -14,14 +14,14 @@ po::options_description MacaonDecode::getOptionsDescription() ("inputTSV", po::value<std::string>(), "File containing the text to decode, TSV file") ("inputTXT", po::value<std::string>(), - "File containing the text to decode, raw text file") - ("mcd", po::value<std::string>()->required(), - "Multi Column Description file that describes the input/output format"); + "File containing the text to decode, raw text file"); po::options_description opt("Optional"); opt.add_options() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") + ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), + "Comma separated column names that describes the input/output format") ("beamSize", po::value<int>()->default_value(1), "Size of the beam during beam search") ("beamThreshold", po::value<float>()->default_value(0.1), @@ -72,7 +72,7 @@ int MacaonDecode::main() auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, "")); auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : ""; auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : ""; - auto mcdFile = variables["mcd"].as<std::string>(); + auto mcd = variables["mcd"].as<std::string>(); bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; auto beamSize = variables["beamSize"].as<int>(); @@ -90,7 +90,7 @@ int MacaonDecode::main() ReadingMachine machine(machinePath, modelPaths); Decoder decoder(machine); - BaseConfig config(mcdFile, inputTSV, inputTXT); + BaseConfig config(mcd, inputTSV, inputTXT); decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index 0b009cdbbe83da767b1dd57038f1c7cfdf2a3905..85d4487a344c33e187885dc1579e5546855eef04 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -19,13 +19,13 @@ class BaseConfig : public Config private : - void readMCD(std::string_view mcdFilename); + void createColumns(std::string mcd); void readRawInput(std::string_view rawFilename); void readTSVInput(std::string_view tsvFilename); public : - BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename); + BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename); BaseConfig(const BaseConfig & other); BaseConfig & operator=(const BaseConfig & other) = default; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 690f937df7076699c4a0ff6b579743b021b4824d..faa594eb8e53fa48cf2a7ba421d1ca6b084bc6e9 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -61,6 +61,7 @@ class Config std::vector<Transition *> appliableSplitTransitions; std::vector<int> appliableTransitions; std::shared_ptr<Strategy> strategy; + std::string mcd; protected : diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 296a58f683b62852d626e113d469765de08e4b62..d365cf63af4fc362a009b4cde8f7956e0c72fd7e 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -1,29 +1,24 @@ #include "BaseConfig.hpp" #include "util.hpp" -void BaseConfig::readMCD(std::string_view mcdFilename) +void BaseConfig::createColumns(std::string mcd) { - if (!colIndex2Name.empty()) - util::myThrow("a mcd has already been read for this BaseConfig"); + this->mcd = mcd; - std::FILE * file = std::fopen(mcdFilename.data(), "r"); + colIndex2Name.clear(); + colName2Index.clear(); - if (not file) - util::myThrow(fmt::format("Cannot open file '{}'", mcdFilename)); - - char lineBuffer[1024]; - while (std::fscanf(file, "%1023[^\n]\n", lineBuffer) == 1) + auto splited = util::split(mcd, ','); + for (auto & colName : splited) { - colIndex2Name.emplace_back(lineBuffer); - colName2Index.emplace(lineBuffer, colIndex2Name.size()-1); + colIndex2Name.emplace_back(colName); + colName2Index.emplace(colName, colIndex2Name.size()-1); } - std::fclose(file); - for (auto & column : extraColumns) { if (colName2Index.count(column)) - util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, column)); + util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcd, column)); colIndex2Name.emplace_back(column); colName2Index.emplace(column, colIndex2Name.size()-1); } @@ -111,6 +106,12 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) if (line[0] == '#') { + if (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)#(?:(?:\\s|\\t)*)global.columns(?:(?:\\s|\\t)*)=(?:(?:\\s|\\t)*)(.+)"), line, [this](const auto & sm) + { + createColumns(util::join(",", util::split(util::strip(sm.str(1)), ' '))); + })) + continue; + addLines(1); get(EOSColName, getNbLines()-1, 0) = EOSSymbol0; get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0; @@ -158,14 +159,12 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( { } -BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) +BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename) { if (tsvFilename.empty() and rawFilename.empty()) util::myThrow("tsvFilename and rawFilenames can't be both empty"); - if (mcdFilename.empty()) - util::myThrow("mcdFilename can't be empty"); - readMCD(mcdFilename); + createColumns(mcd); if (not rawFilename.empty()) readRawInput(rawFilename); @@ -192,7 +191,7 @@ std::size_t BaseConfig::getColIndex(const std::string & colName) const { auto it = colName2Index.find(colName); if (it == colName2Index.end()) - util::myThrow(fmt::format("unknown column name '{}'", colName)); + util::myThrow(fmt::format("unknown column name '{}', mcd = '{}'", colName, mcd)); return it->second; } diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index 81fbc58b66c6916bae6b7b0bce7542fb84102e24..4e320b9d7b80304bf51358452b668d692895813e 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -19,7 +19,8 @@ Config::Config(const Config & other) this->state = other.state; this->history = other.history; this->stack = other.stack; - this->extraColumns = this->extraColumns; + this->extraColumns = other.extraColumns; + this->mcd = other.mcd; } std::size_t Config::getIndexOfLine(int lineIndex) const @@ -106,6 +107,8 @@ void Config::print(FILE * dest) const currentSequenceComments.clear(); }; + fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ','))); + for (unsigned int line = 0; line < getNbLines(); line++) { if (isComment(getFirstLineIndex()+line)) diff --git a/reading_machine/src/SubConfig.cpp b/reading_machine/src/SubConfig.cpp index eef8ae778a0fa38b3f89b48e29cfe12c7fa35583..7b7b65ead097e4bb1d0d33b0a5bbd9d8b55831a9 100644 --- a/reading_machine/src/SubConfig.cpp +++ b/reading_machine/src/SubConfig.cpp @@ -15,6 +15,7 @@ SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : model(model), s currentWordId = model.currentWordId; appliableSplitTransitions = model.appliableSplitTransitions; appliableTransitions = model.appliableTransitions; + mcd = model.mcd; if (model.strategy.get() != nullptr) strategy.reset(new Strategy(model.getStrategy())); update(); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 85c1def8a041efe0ab2a8d4bd009b1aa4504d703..ea0f6f52eced990cc0e05f893c44302e20486bbb 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -13,8 +13,6 @@ po::options_description MacaonTrain::getOptionsDescription() req.add_options() ("model", po::value<std::string>()->required(), "Directory containing the machine file to train") - ("mcd", po::value<std::string>()->required(), - "Multi Column Description file that describes the input format") ("trainTSV", po::value<std::string>()->required(), "TSV file of the training corpus, in CONLLU format"); @@ -23,6 +21,8 @@ po::options_description MacaonTrain::getOptionsDescription() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") ("devScore", "Compute score on dev instead of loss (slower)") + ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), + "Comma separated column names that describes the input/output format") ("trainTXT", po::value<std::string>()->default_value(""), "Raw text file of the training corpus") ("devTSV", po::value<std::string>()->default_value(""), @@ -97,7 +97,7 @@ int MacaonTrain::main() std::filesystem::path modelPath(variables["model"].as<std::string>()); auto machinePath = modelPath / "machine.rm"; - auto mcdFile = variables["mcd"].as<std::string>(); + auto mcd = variables["mcd"].as<std::string>(); auto trainTsvFile = variables["trainTSV"].as<std::string>(); auto trainRawFile = variables["trainTXT"].as<std::string>(); auto devTsvFile = variables["devTSV"].as<std::string>(); @@ -132,8 +132,8 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string()); - BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); - BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); + BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); + BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); Trainer trainer(machine, batchSize); Decoder decoder(machine); @@ -230,7 +230,7 @@ int MacaonTrain::main() std::vector<std::pair<float,std::string>> devScores; if (computeDevScore) { - BaseConfig devConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); + BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); decoder.evaluate(devConfig, modelPath, devTsvFile); devScores = decoder.getF1Scores(machine.getPredicted());