From fffe386ee26133da469fd33a2606beeea7b52b8a Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 16 Feb 2021 22:27:09 +0100 Subject: [PATCH] Added lineByLine in train in tsv mode --- trainer/src/MacaonTrain.cpp | 70 ++++++++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 12 deletions(-) diff --git a/trainer/src/MacaonTrain.cpp b/trainer/src/MacaonTrain.cpp index dba21cd..91d0ea8 100644 --- a/trainer/src/MacaonTrain.cpp +++ b/trainer/src/MacaonTrain.cpp @@ -117,6 +117,29 @@ inline auto decay(Optimizer &optimizer, double rate) -> void } } +int countNbSentences(std::string_view filename) +{ + std::FILE * file = std::fopen(filename.data(), "r"); + + if (not file) + util::myThrow(fmt::format("Cannot open file '{}'", filename)); + + char lineBuffer[100000]; + int nbSent = 0; + + while (!std::feof(file)) + { + if (lineBuffer != std::fgets(lineBuffer, 100000, file)) + break; + if (std::string(lineBuffer).size() < 3) + nbSent++; + } + + std::fclose(file); + + return nbSent; +} + int MacaonTrain::main() { auto od = getOptionsDescription(); @@ -167,29 +190,52 @@ int MacaonTrain::main() ReadingMachine machine(machinePath.string(), true); + int nbSentencesTrain = countNbSentences(trainTsvFile); + std::vector<util::utf8string> trainRawInputs; if (!trainRawFile.empty()) trainRawInputs = util::readFileAsUtf8(trainRawFile, lineByLine); std::vector<BaseConfig> goldConfigs; - for (int i = 0; i < (int)trainRawInputs.size(); i++) - if (trainRawInputs.size() == 1) - goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>()); + if (lineByLine) + { + if (trainRawInputs.size()) + for (int i = 0; i < (int)trainRawInputs.size(); i++) + goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i}); else - goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i}); - if (goldConfigs.empty()) - goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>()); + for (int i = 0; i < nbSentencesTrain; i++) + goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>{i}); + } + else + { + if (trainRawInputs.size()) + goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>()); + else + goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>()); + } + + + int nbSentencesDev = countNbSentences(devTsvFile); std::vector<util::utf8string> devRawInputs; if (!devRawFile.empty()) devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine); std::vector<BaseConfig> devGoldConfigs; - for (int i = 0; i < (int)devRawInputs.size(); i++) - if (devRawInputs.size() == 1) - devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>()); + if (lineByLine) + { + if (devRawInputs.size()) + for (int i = 0; i < (int)devRawInputs.size(); i++) + devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i}); + else + for (int i = 0; i < nbSentencesDev; i++) + devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i}); + } + else + { + if (devRawInputs.size()) + devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>()); else - devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i}); - if (devGoldConfigs.empty()) - devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>()); + devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>()); + } Trainer trainer(machine, batchSize); Decoder decoder(machine); -- GitLab