diff --git a/common/include/util.hpp b/common/include/util.hpp index 478ff9794ee52720634498e1a945a7b4e9657228..bd1a48b19470bc21b24ead6d808fff07e20689ca 100644 --- a/common/include/util.hpp +++ b/common/include/util.hpp @@ -54,6 +54,8 @@ std::string getTime(); long float2long(float f); float long2float(long l); +std::vector<std::vector<std::string>> readTSV(std::string_view tsvFilename); /* Loads a TSV file as a list of sentences, each sentence being the list of raw lines read for it (defined in util.cpp). */ + template <typename T> bool isEmpty(const std::vector<T> & s) { diff --git a/common/src/util.cpp b/common/src/util.cpp index d896173ccb8ab4d5206eb8a7f57fd9a683ed81d1..f5717be08bcd3ecb4aa4cdd39fe7f6b6ad97d7b9 100644 --- a/common/src/util.cpp +++ b/common/src/util.cpp @@ -365,3 +365,45 @@ std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bo return res; } +std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename) +{ /* Reads the file with std::fgets and groups its lines, stored as fgets returned them, into sentences: a line shorter than 3 characters closes the current sentence. A failed fopen is reported through util::myThrow. */ + std::vector<std::vector<std::string>> sentences; + + std::FILE * file = std::fopen(tsvFilename.data(), "r"); /* NOTE(review): string_view::data() is not guaranteed to be null-terminated, which fopen needs - confirm that callers always pass a view over a whole std::string or C string. */ + + if (not file) + util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename)); + + char lineBuffer[100000]; /* fgets stores at most 99999 characters per call, so a longer physical line ends up split over several entries */ + bool inputHasBeenRead = false; + + sentences.emplace_back(); + while (!std::feof(file)) + { + if (lineBuffer != std::fgets(lineBuffer, 100000, file)) /* fgets returns a null pointer on end of file or on error */ + break; + + std::string_view line(lineBuffer); + sentences.back().emplace_back(line); + + if (line.size() < 3) /* separator line (the size includes the newline kept by fgets); it was already appended to the current sentence just above */ + { + if (!inputHasBeenRead) /* separators met before any longer line do not open a new sentence */ + continue; + + sentences.emplace_back(); + continue; + } + + inputHasBeenRead = true; + } + + if (sentences.back().empty()) /* drop a last sentence that never received a line (empty file, or input ending on a separator) */ + sentences.pop_back(); + + std::fclose(file); /* NOTE(review): not reached if an exception escapes the loop above (e.g. from emplace_back), which would leave the FILE open */ + + return sentences; +} + + diff --git a/decoder/src/MacaonDecode.cpp b/decoder/src/MacaonDecode.cpp index e4194e84005028034e55ccb79e11cd9d10d3b1cc..2e4ed9fe811b23d6865878f5f2ce8a8bb0510287 100644 --- a/decoder/src/MacaonDecode.cpp +++ b/decoder/src/MacaonDecode.cpp @@ -96,25 +96,38 @@ int MacaonDecode::main() ReadingMachine machine(machinePath, false); Decoder decoder(machine); - if (inputTXT.empty()) + std::vector<util::utf8string> 
rawInputs; + if (!inputTXT.empty()) + rawInputs = util::readFileAsUtf8(inputTXT, lineByLine); + + std::vector<std::vector<std::string>> tsv; + if (!inputTSV.empty()) + tsv = util::readTSV(inputTSV); /* the TSV file is parsed once here and the result is handed to every BaseConfig built below */ + + std::vector<BaseConfig> configs; /* lineByLine mode: one config per raw input line, or per TSV sentence when there is no raw input, each given the single sentence index i; otherwise one config with an empty index list */ + if (lineByLine) { - BaseConfig config(mcd, inputTSV, util::utf8string(), std::vector<int>()); - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); - config.print(stdout, true); + if (rawInputs.size()) + for (unsigned int i = 0; i < rawInputs.size(); i++) + configs.emplace_back(mcd, tsv, rawInputs[i], std::vector<int>{(int)i}); + else + for (unsigned int i = 0; i < tsv.size(); i++) + configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>{(int)i}); } else { - auto inputs = util::readFileAsUtf8(inputTXT, lineByLine); - for (int i = 0; i < (int)inputs.size(); i++) - { - std::vector<int> sentIndexes; - if (inputs.size() > 1) - sentIndexes.emplace_back(i); - BaseConfig config(mcd, inputTSV, inputs[i], sentIndexes); - decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); - config.print(stdout, i == 0); - } + if (rawInputs.size()) + configs.emplace_back(mcd, tsv, rawInputs[0], std::vector<int>()); + else + configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>()); + } + + for (unsigned int i = 0; i < configs.size(); i++) /* decode then print each config in order; the second argument of print is true only for the first one */ + { + decoder.decode(configs[i], beamSize, beamThreshold, debug, printAdvancement); + configs[i].print(stdout, i == 0); } + } catch(std::exception & e) {util::error(e);} return 0; diff --git a/reading_machine/include/BaseConfig.hpp b/reading_machine/include/BaseConfig.hpp index a93b1804dfe23a470dd1794a8338a390a923a408..167b8fecaac094c4d518b1e640cd10030e21a353 100644 --- a/reading_machine/include/BaseConfig.hpp +++ b/reading_machine/include/BaseConfig.hpp @@ -21,11 +21,11 @@ class BaseConfig : public Config void createColumns(std::string mcd); void readRawInput(std::string_view rawFilename); - void readTSVInput(std::string_view tsvFilename, const 
std::vector<int> & sentencesIndexes); + void readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes); /* now receives sentences that were already read (the callers in this patch obtain them from util::readTSV) instead of a file name */ public : - BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes); + BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes); BaseConfig(const BaseConfig & other); BaseConfig & operator=(const BaseConfig & other) = default; diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index 69905a8d32309f3ac036bc136efe9913fd5983c3..0d373c6e3df2a052ac69a1ded930b3eb6a11ff20 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -43,44 +43,8 @@ void BaseConfig::readRawInput(std::string_view rawFilename) rawInput.replace(util::utf8char("\t"), util::utf8char(" ")); } -void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes) +void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes) { - std::vector<std::vector<std::string>> sentences; - - std::FILE * file = std::fopen(tsvFilename.data(), "r"); - - if (not file) - util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename)); - - char lineBuffer[100000]; - bool inputHasBeenRead = false; - - sentences.emplace_back(); - while (!std::feof(file)) - { - if (lineBuffer != std::fgets(lineBuffer, 100000, file)) - break; - - std::string_view line(lineBuffer); - sentences.back().emplace_back(line); - - if (line.size() < 3) - { - if (!inputHasBeenRead) - continue; - - sentences.emplace_back(); - continue; - } - - inputHasBeenRead = true; - } - - if (sentences.back().empty()) - sentences.pop_back(); - - std::fclose(file); - for (unsigned int i = 0; i < (sentencesIndexes.size() ? 
sentencesIndexes.size() : sentences.size()); i++) { int targetSentenceIndex = sentencesIndexes.size() ? sentencesIndexes[i] : i; @@ -115,7 +79,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<in if (usualNbCol == -1) usualNbCol = splited.size(); if ((int)splited.size() != usualNbCol) - util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol)); + util::myThrow(fmt::format("in tsv file line {} is invalid, it shoud have {} columns", line, usualNbCol)); /* the file name is no longer a parameter here, so it was dropped from the message. NOTE(review): shoud in this runtime string is a typo for should; left untouched. */ // Ignore empty nodes if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos) @@ -182,10 +146,10 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( { } -BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes) +BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes) { - if (tsvFilename.empty() and rawInput.empty()) - util::myThrow("tsvFilename and rawInput can't be both empty"); + if (sentences.empty() and rawInput.empty()) /* an empty sentence list now plays the role the empty file name had in the removed check */ + util::myThrow("sentences and rawInput can't be both empty"); createColumns(mcd); @@ -193,8 +157,8 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util this->rawInput = rawInput; // sentencesIndexes = index of sentences to keep. Empty vector == keep all sentences. 
- if (not tsvFilename.empty()) - readTSVInput(tsvFilename, sentencesIndexes); + if (not sentences.empty()) + readTSVInput(sentences, sentencesIndexes); if (!has(0,wordIndex,0)) addLines(1); diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index 91d0ea82c96151d9d44af01ce41d89b7abf5bdb2..02d58fd7a7b7398e9d6b6d08d51aeb07ca1948cd 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -117,29 +117,6 @@ inline auto decay(Optimizer &optimizer, double rate) -> void } } -int countNbSentences(std::string_view filename) -{ - std::FILE * file = std::fopen(filename.data(), "r"); - - if (not file) - util::myThrow(fmt::format("Cannot open file '{}'", filename)); - - char lineBuffer[100000]; - int nbSent = 0; - - while (!std::feof(file)) - { - if (lineBuffer != std::fgets(lineBuffer, 100000, file)) - break; - if (std::string(lineBuffer).size() < 3) - nbSent++; - } - - std::fclose(file); - - return nbSent; -} - int MacaonTrain::main() { auto od = getOptionsDescription(); @@ -188,9 +165,13 @@ int MacaonTrain::main() try { - ReadingMachine machine(machinePath.string(), true); + std::vector<std::vector<std::string>> trainTsv, devTsv; /* both TSV files are read once here; trainTsv.size() and devTsv.size() are used below where the results of the removed countNbSentences helper were */ + if (!trainTsvFile.empty()) + trainTsv = util::readTSV(trainTsvFile); + if (!devTsvFile.empty()) + devTsv = util::readTSV(devTsvFile); - int nbSentencesTrain = countNbSentences(trainTsvFile); + ReadingMachine machine(machinePath.string(), true); std::vector<util::utf8string> trainRawInputs; if (!trainRawFile.empty()) @@ -199,23 +180,20 @@ int MacaonTrain::main() if (lineByLine) { if (trainRawInputs.size()) - for (int i = 0; i < (int)trainRawInputs.size(); i++) - goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i}); + for (unsigned int i = 0; i < trainRawInputs.size(); i++) + goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[i], std::vector<int>{(int)i}); else - for (int i = 0; i < nbSentencesTrain; i++) - goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), 
std::vector<int>{i}); + for (unsigned int i = 0; i < trainTsv.size(); i++) /* no raw text: one gold config per sentence returned by util::readTSV */ + goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>{(int)i}); } else { if (trainRawInputs.size()) - goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>()); + goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[0], std::vector<int>()); else - goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>()); + goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>()); } - - int nbSentencesDev = countNbSentences(devTsvFile); - std::vector<util::utf8string> devRawInputs; if (!devRawFile.empty()) devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine); @@ -223,18 +201,18 @@ int MacaonTrain::main() if (lineByLine) { if (devRawInputs.size()) - for (int i = 0; i < (int)devRawInputs.size(); i++) - devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i}); + for (unsigned int i = 0; i < devRawInputs.size(); i++) + devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i}); else - for (int i = 0; i < nbSentencesDev; i++) - devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i}); + for (unsigned int i = 0; i < devTsv.size(); i++) + devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i}); } else { if (devRawInputs.size()) - devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>()); + devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>()); else - devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>()); + devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>()); } Trainer trainer(machine, batchSize); @@ -322,13 +300,23 @@ int MacaonTrain::main() if (computeDevScore) { std::vector<BaseConfig> devConfigs; - for (int i = 0; i < (int)devRawInputs.size(); i++) - if (devRawInputs.size() == 1) - 
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>()); + if (lineByLine) /* same layout as for devGoldConfigs: per-line or per-sentence configs in lineByLine mode, a single config otherwise. NOTE(review): the removed lines passed an empty TSV file name whenever devRawFile was not empty, whereas devTsv is now always passed - confirm this is intended for dev scoring. */ + { + if (devRawInputs.size()) + for (unsigned int i = 0; i < devRawInputs.size(); i++) + devConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i}); + else + for (unsigned int i = 0; i < devTsv.size(); i++) + devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i}); + } + else + { + if (devRawInputs.size()) + devConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>()); else - devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i}); - if (devConfigs.empty()) - devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>()); + devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>()); + } + for (auto & devConfig : devConfigs) decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);