Commit ab0bfc27 authored by Franck Dary's avatar Franck Dary
Browse files

Working lineByLine in tsv mode

parent fffe386e
......@@ -54,6 +54,8 @@ std::string getTime();
long float2long(float f);
float long2float(long l);
std::vector<std::vector<std::string>> readTSV(std::string_view tsvFilename);
template <typename T>
bool isEmpty(const std::vector<T> & s)
{
......
......@@ -365,3 +365,45 @@ std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bo
return res;
}
std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename)
{
std::vector<std::vector<std::string>> sentences;
std::FILE * file = std::fopen(tsvFilename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
char lineBuffer[100000];
bool inputHasBeenRead = false;
sentences.emplace_back();
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
std::string_view line(lineBuffer);
sentences.back().emplace_back(line);
if (line.size() < 3)
{
if (!inputHasBeenRead)
continue;
sentences.emplace_back();
continue;
}
inputHasBeenRead = true;
}
if (sentences.back().empty())
sentences.pop_back();
std::fclose(file);
return sentences;
}
......@@ -96,25 +96,38 @@ int MacaonDecode::main()
ReadingMachine machine(machinePath, false);
Decoder decoder(machine);
if (inputTXT.empty())
std::vector<util::utf8string> rawInputs;
if (!inputTXT.empty())
rawInputs = util::readFileAsUtf8(inputTXT, lineByLine);
std::vector<std::vector<std::string>> tsv;
if (!inputTSV.empty())
tsv = util::readTSV(inputTSV);
std::vector<BaseConfig> configs;
if (lineByLine)
{
BaseConfig config(mcd, inputTSV, util::utf8string(), std::vector<int>());
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, true);
if (rawInputs.size())
for (unsigned int i = 0; i < rawInputs.size(); i++)
configs.emplace_back(mcd, tsv, rawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < tsv.size(); i++)
configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
auto inputs = util::readFileAsUtf8(inputTXT, lineByLine);
for (int i = 0; i < (int)inputs.size(); i++)
{
std::vector<int> sentIndexes;
if (inputs.size() > 1)
sentIndexes.emplace_back(i);
BaseConfig config(mcd, inputTSV, inputs[i], sentIndexes);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, i == 0);
}
if (rawInputs.size())
configs.emplace_back(mcd, tsv, rawInputs[0], std::vector<int>());
else
configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>());
}
for (unsigned int i = 0; i < configs.size(); i++)
{
decoder.decode(configs[i], beamSize, beamThreshold, debug, printAdvancement);
configs[i].print(stdout, i == 0);
}
} catch(std::exception & e) {util::error(e);}
return 0;
......
......@@ -21,11 +21,11 @@ class BaseConfig : public Config
void createColumns(std::string mcd);
void readRawInput(std::string_view rawFilename);
void readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes);
void readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes);
public :
BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
BaseConfig(const BaseConfig & other);
BaseConfig & operator=(const BaseConfig & other) = default;
......
......@@ -43,44 +43,8 @@ void BaseConfig::readRawInput(std::string_view rawFilename)
rawInput.replace(util::utf8char("\t"), util::utf8char(" "));
}
void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes)
void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes)
{
std::vector<std::vector<std::string>> sentences;
std::FILE * file = std::fopen(tsvFilename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
char lineBuffer[100000];
bool inputHasBeenRead = false;
sentences.emplace_back();
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
std::string_view line(lineBuffer);
sentences.back().emplace_back(line);
if (line.size() < 3)
{
if (!inputHasBeenRead)
continue;
sentences.emplace_back();
continue;
}
inputHasBeenRead = true;
}
if (sentences.back().empty())
sentences.pop_back();
std::fclose(file);
for (unsigned int i = 0; i < (sentencesIndexes.size() ? sentencesIndexes.size() : sentences.size()); i++)
{
int targetSentenceIndex = sentencesIndexes.size() ? sentencesIndexes[i] : i;
......@@ -115,7 +79,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<in
if (usualNbCol == -1)
usualNbCol = splited.size();
if ((int)splited.size() != usualNbCol)
util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol));
util::myThrow(fmt::format("in tsv file line {} is invalid, it shoud have {} columns", line, usualNbCol));
// Ignore empty nodes
if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos)
......@@ -182,10 +146,10 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
{
}
BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
{
if (tsvFilename.empty() and rawInput.empty())
util::myThrow("tsvFilename and rawInput can't be both empty");
if (sentences.empty() and rawInput.empty())
util::myThrow("sentences and rawInput can't be both empty");
createColumns(mcd);
......@@ -193,8 +157,8 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util
this->rawInput = rawInput;
// sentencesIndexes = index of sentences to keep. Empty vector == keep all sentences.
if (not tsvFilename.empty())
readTSVInput(tsvFilename, sentencesIndexes);
if (not sentences.empty())
readTSVInput(sentences, sentencesIndexes);
if (!has(0,wordIndex,0))
addLines(1);
......
......@@ -117,29 +117,6 @@ inline auto decay(Optimizer &optimizer, double rate) -> void
}
}
int countNbSentences(std::string_view filename)
{
std::FILE * file = std::fopen(filename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", filename));
char lineBuffer[100000];
int nbSent = 0;
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
if (std::string(lineBuffer).size() < 3)
nbSent++;
}
std::fclose(file);
return nbSent;
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......@@ -188,9 +165,13 @@ int MacaonTrain::main()
try
{
ReadingMachine machine(machinePath.string(), true);
std::vector<std::vector<std::string>> trainTsv, devTsv;
if (!trainTsvFile.empty())
trainTsv = util::readTSV(trainTsvFile);
if (!devTsvFile.empty())
devTsv = util::readTSV(devTsvFile);
int nbSentencesTrain = countNbSentences(trainTsvFile);
ReadingMachine machine(machinePath.string(), true);
std::vector<util::utf8string> trainRawInputs;
if (!trainRawFile.empty())
......@@ -199,23 +180,20 @@ int MacaonTrain::main()
if (lineByLine)
{
if (trainRawInputs.size())
for (int i = 0; i < (int)trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i});
for (unsigned int i = 0; i < trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[i], std::vector<int>{(int)i});
else
for (int i = 0; i < nbSentencesTrain; i++)
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>{i});
for (unsigned int i = 0; i < trainTsv.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (trainRawInputs.size())
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>());
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[0], std::vector<int>());
else
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>());
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>());
}
int nbSentencesDev = countNbSentences(devTsvFile);
std::vector<util::utf8string> devRawInputs;
if (!devRawFile.empty())
devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine);
......@@ -223,18 +201,18 @@ int MacaonTrain::main()
if (lineByLine)
{
if (devRawInputs.size())
for (int i = 0; i < (int)devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i});
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (int i = 0; i < nbSentencesDev; i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i});
for (unsigned int i = 0; i < devTsv.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>());
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
else
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>());
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
Trainer trainer(machine, batchSize);
......@@ -322,13 +300,23 @@ int MacaonTrain::main()
if (computeDevScore)
{
std::vector<BaseConfig> devConfigs;
for (int i = 0; i < (int)devRawInputs.size(); i++)
if (devRawInputs.size() == 1)
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>());
if (lineByLine)
{
if (devRawInputs.size())
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < devTsv.size(); i++)
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
else
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i});
if (devConfigs.empty())
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>());
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
for (auto & devConfig : devConfigs)
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment