Skip to content
Snippets Groups Projects
Commit fffe386e authored by Franck Dary's avatar Franck Dary
Browse files

Added lineByLine in train in tsv mode

parent 2ebf30b6
No related branches found
No related tags found
No related merge requests found
......@@ -117,6 +117,29 @@ inline auto decay(Optimizer &optimizer, double rate) -> void
}
}
int countNbSentences(std::string_view filename)
{
std::FILE * file = std::fopen(filename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", filename));
char lineBuffer[100000];
int nbSent = 0;
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
if (std::string(lineBuffer).size() < 3)
nbSent++;
}
std::fclose(file);
return nbSent;
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......@@ -167,29 +190,52 @@ int MacaonTrain::main()
ReadingMachine machine(machinePath.string(), true);
int nbSentencesTrain = countNbSentences(trainTsvFile);
std::vector<util::utf8string> trainRawInputs;
if (!trainRawFile.empty())
trainRawInputs = util::readFileAsUtf8(trainRawFile, lineByLine);
std::vector<BaseConfig> goldConfigs;
for (int i = 0; i < (int)trainRawInputs.size(); i++)
if (trainRawInputs.size() == 1)
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>());
if (lineByLine)
{
if (trainRawInputs.size())
for (int i = 0; i < (int)trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i});
else
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i});
if (goldConfigs.empty())
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>());
for (int i = 0; i < nbSentencesTrain; i++)
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>{i});
}
else
{
if (trainRawInputs.size())
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>());
else
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>());
}
int nbSentencesDev = countNbSentences(devTsvFile);
std::vector<util::utf8string> devRawInputs;
if (!devRawFile.empty())
devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine);
std::vector<BaseConfig> devGoldConfigs;
for (int i = 0; i < (int)devRawInputs.size(); i++)
if (devRawInputs.size() == 1)
devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>());
if (lineByLine)
{
if (devRawInputs.size())
for (int i = 0; i < (int)devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i});
else
for (int i = 0; i < nbSentencesDev; i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i});
}
else
{
if (devRawInputs.size())
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>());
else
devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i});
if (devGoldConfigs.empty())
devGoldConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>());
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>());
}
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment