diff --git a/common/include/util.hpp b/common/include/util.hpp index 1a71486b5dcb0fe0c8c037e1b60ae3151e286d26..478ff9794ee52720634498e1a945a7b4e9657228 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -106,6 +106,8 @@ utf8string upper(const utf8string & s); void upper(utf8char & c); +std::vector<utf8string> readFileAsUtf8(std::string_view filename, bool lineByLine); + template <typename T> std::string join(const std::string & delim, const std::vector<T> elems) { diff --git a/common/src/util.cpp b/common/src/util.cpp index b084f81748bce7e5cb252a6e1b25ff1fba8d9312..d896173ccb8ab4d5206eb8a7f57fd9a683ed81d1 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -318,3 +318,50 @@ void util::upper(utf8char & c) c = it->second; } +std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bool lineByLine) +{ + std::vector<utf8string> res; + std::FILE * file = std::fopen(filename.data(), "r"); + + if (not file) + util::myThrow(fmt::format("Cannot open file '{}'", filename)); + + std::string lineTemp; + + if (!lineByLine) + { + while (not std::feof(file)) + lineTemp.push_back(std::fgetc(file)); + + + auto line = util::splitAsUtf8(lineTemp); + line.replace(util::utf8char("\n"), util::utf8char(" ")); + line.replace(util::utf8char("\t"), util::utf8char(" ")); + + res.emplace_back(line); + } + else + { + while (not std::feof(file)) + { + lineTemp.clear(); + while (not std::feof(file)) + { + lineTemp.push_back(std::fgetc(file)); + if (lineTemp.back() == '\n') + break; + } + + auto line = util::splitAsUtf8(lineTemp); + line.replace(util::utf8char("\n"), util::utf8char(" ")); + line.replace(util::utf8char("\t"), util::utf8char(" ")); + if (!line.empty()) + res.emplace_back(line); + } + } + + std::fclose(file); + + return res; +} + diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index 65b49d0b606972bd84f79f5614651a7356770c39..0bbf5aab3a6a41bf0835581accf861e409341d29 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -22,6 +22,7 @@ po::options_description MacaonDecode::getOptionsDescription() ("debug,d", "Print debuging infos on stderr") ("silent", "Don't print speed and progress") ("reloadEmbeddings", "Reload pretrained embeddings") + ("lineByLine", "Treat the TXT input as being one different text per line.") ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), "Comma separated column names that describes the input/output format") ("beamSize", po::value<int>()->default_value(1), @@ -78,6 +79,7 @@ int MacaonDecode::main() bool debug = variables.count("debug") == 0 ? false : true; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true; + bool lineByLine = variables.count("lineByLine") == 0 ? false : true; auto beamSize = variables["beamSize"].as<int>(); auto beamThreshold = variables["beamThreshold"].as<float>(); @@ -94,11 +96,22 @@ int MacaonDecode::main() ReadingMachine machine(machinePath, false); Decoder decoder(machine); - BaseConfig config(mcd, inputTSV, inputTXT); - - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); - - config.print(stdout); + if (inputTXT.empty()) + { + BaseConfig config(mcd, inputTSV, util::utf8string()); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); + config.print(stdout, true); + } + else + { + auto inputs = util::readFileAsUtf8(inputTXT, lineByLine); + for (unsigned int i = 0; i < inputs.size(); i++) + { + BaseConfig config(mcd, inputTSV, inputs[i]); + decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); + config.print(stdout, i == 0); + } + } } catch(std::exception & e) {util::error(e);} return 0; diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index 85d4487a344c33e187885dc1579e5546855eef04..d77d0a0720ed14439412258a6fa9813514aac173 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -25,7 +25,7 @@ class BaseConfig : public Config public : - BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename); + BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename); BaseConfig(const BaseConfig & other); BaseConfig & operator=(const BaseConfig & other) = default; diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp index 42b55af78b19e03d6ccada9e2faf017c2d1d444a..68408c8ab2a7ab060ea1072c258aea74d15d3a0b 100644 --- a/reading_machine/include/Config.hpp +++ b/reading_machine/include/Config.hpp @@ -107,7 +107,7 @@ class Config public : - void print(FILE * dest) const; + void print(FILE * dest, bool printHeader = true) const; void printForDebug(FILE * dest) const; bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const; String & get(const std::string & colName, int lineIndex, int hypothesisIndex); diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index abb229e42701acb7594b7dc776d227dcdbdeb900..298d4537b064b4126d7f67e04129abceeaf6529c 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -161,15 +161,15 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( { } -BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename) +BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput) { - if (tsvFilename.empty() and rawFilename.empty()) - util::myThrow("tsvFilename and rawFilenames can't be both empty"); + if (tsvFilename.empty() and rawInput.empty()) + util::myThrow("tsvFilename and rawInput can't be both empty"); createColumns(mcd); - if (not rawFilename.empty()) - readRawInput(rawFilename); + if (not rawInput.empty()) + this->rawInput = rawInput; if (not tsvFilename.empty()) readTSVInput(tsvFilename); diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index ffe2f3e7a1c2a13ee04133f5d8963b9a49765621..ddc5477fb58fce6e3e465e5549ea7f31ea285bbc 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -78,7 +78,7 @@ std::size_t Config::getNbLines() const return lines.size() / getIndexOfCol(getNbColumns()); } -void Config::print(FILE * dest) const +void Config::print(FILE * dest, bool printHeader) const { std::vector<std::string> currentSequence; std::vector<std::string> currentSequenceComments; @@ -100,7 +100,8 @@ void Config::print(FILE * dest) const currentSequenceComments.clear(); }; - fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ','))); + if (printHeader) + fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ','))); for (unsigned int line = 0; line < getNbLines(); line++) { diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 6d7266e18228d3dd130bd668bdec2972c0dabbb9..a0492974a6370f4f48a1d105664e1bcad129154b 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -165,8 +165,20 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string(), true); - BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); - BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); + util::utf8string trainRawInput; + if (!trainRawFile.empty()) + { + auto input = util::readFileAsUtf8(trainRawFile, false); + trainRawInput = input[0]; + } + BaseConfig goldConfig(mcd, trainTsvFile, trainRawInput); + util::utf8string devRawInput; + if (!devRawFile.empty()) + { + auto input = util::readFileAsUtf8(devRawFile, false); + devRawInput = input[0]; + } + BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput); Trainer trainer(machine, batchSize); Decoder decoder(machine); @@ -251,7 +263,7 @@ int MacaonTrain::main() std::vector<std::pair<float,std::string>> devScores; if (computeDevScore) { - BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); + BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput); decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); decoder.evaluate(devConfig, modelPath, devTsvFile, machine.getPredicted()); devScores = decoder.getF1Scores(machine.getPredicted());