Skip to content
Snippets Groups Projects
Commit ab0bfc27 authored by Franck Dary's avatar Franck Dary
Browse files

Working lineByLine in tsv mode

parent fffe386e
No related branches found
No related tags found
No related merge requests found
......@@ -54,6 +54,8 @@ std::string getTime();
long float2long(float f);
float long2float(long l);
std::vector<std::vector<std::string>> readTSV(std::string_view tsvFilename);
template <typename T>
bool isEmpty(const std::vector<T> & s)
{
......
......@@ -365,3 +365,45 @@ std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bo
return res;
}
std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename)
{
std::vector<std::vector<std::string>> sentences;
std::FILE * file = std::fopen(tsvFilename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
char lineBuffer[100000];
bool inputHasBeenRead = false;
sentences.emplace_back();
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
std::string_view line(lineBuffer);
sentences.back().emplace_back(line);
if (line.size() < 3)
{
if (!inputHasBeenRead)
continue;
sentences.emplace_back();
continue;
}
inputHasBeenRead = true;
}
if (sentences.back().empty())
sentences.pop_back();
std::fclose(file);
return sentences;
}
......@@ -96,25 +96,38 @@ int MacaonDecode::main()
ReadingMachine machine(machinePath, false);
Decoder decoder(machine);
if (inputTXT.empty())
std::vector<util::utf8string> rawInputs;
if (!inputTXT.empty())
rawInputs = util::readFileAsUtf8(inputTXT, lineByLine);
std::vector<std::vector<std::string>> tsv;
if (!inputTSV.empty())
tsv = util::readTSV(inputTSV);
std::vector<BaseConfig> configs;
if (lineByLine)
{
BaseConfig config(mcd, inputTSV, util::utf8string(), std::vector<int>());
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, true);
if (rawInputs.size())
for (unsigned int i = 0; i < rawInputs.size(); i++)
configs.emplace_back(mcd, tsv, rawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < tsv.size(); i++)
configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
auto inputs = util::readFileAsUtf8(inputTXT, lineByLine);
for (int i = 0; i < (int)inputs.size(); i++)
{
std::vector<int> sentIndexes;
if (inputs.size() > 1)
sentIndexes.emplace_back(i);
BaseConfig config(mcd, inputTSV, inputs[i], sentIndexes);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, i == 0);
}
if (rawInputs.size())
configs.emplace_back(mcd, tsv, rawInputs[0], std::vector<int>());
else
configs.emplace_back(mcd, tsv, util::utf8string(), std::vector<int>());
}
for (unsigned int i = 0; i < configs.size(); i++)
{
decoder.decode(configs[i], beamSize, beamThreshold, debug, printAdvancement);
configs[i].print(stdout, i == 0);
}
} catch(std::exception & e) {util::error(e);}
return 0;
......
......@@ -21,11 +21,11 @@ class BaseConfig : public Config
void createColumns(std::string mcd);
void readRawInput(std::string_view rawFilename);
void readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes);
void readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes);
public :
BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawFilename, const std::vector<int> & sentencesIndexes);
BaseConfig(const BaseConfig & other);
BaseConfig & operator=(const BaseConfig & other) = default;
......
......@@ -43,44 +43,8 @@ void BaseConfig::readRawInput(std::string_view rawFilename)
rawInput.replace(util::utf8char("\t"), util::utf8char(" "));
}
void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<int> & sentencesIndexes)
void BaseConfig::readTSVInput(const std::vector<std::vector<std::string>> & sentences, const std::vector<int> & sentencesIndexes)
{
std::vector<std::vector<std::string>> sentences;
std::FILE * file = std::fopen(tsvFilename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", tsvFilename));
char lineBuffer[100000];
bool inputHasBeenRead = false;
sentences.emplace_back();
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
std::string_view line(lineBuffer);
sentences.back().emplace_back(line);
if (line.size() < 3)
{
if (!inputHasBeenRead)
continue;
sentences.emplace_back();
continue;
}
inputHasBeenRead = true;
}
if (sentences.back().empty())
sentences.pop_back();
std::fclose(file);
for (unsigned int i = 0; i < (sentencesIndexes.size() ? sentencesIndexes.size() : sentences.size()); i++)
{
int targetSentenceIndex = sentencesIndexes.size() ? sentencesIndexes[i] : i;
......@@ -115,7 +79,7 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename, const std::vector<in
if (usualNbCol == -1)
usualNbCol = splited.size();
if ((int)splited.size() != usualNbCol)
util::myThrow(fmt::format("in file {} line {} is invalid, it shoud have {} columns", tsvFilename, line, usualNbCol));
util::myThrow(fmt::format("in tsv file line {} is invalid, it shoud have {} columns", line, usualNbCol));
// Ignore empty nodes
if (hasColIndex(idColName) && splited[getColIndex(idColName)].find('.') != std::string::npos)
......@@ -182,10 +146,10 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
{
}
BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
BaseConfig::BaseConfig(std::string mcd, const std::vector<std::vector<std::string>> & sentences, const util::utf8string & rawInput, const std::vector<int> & sentencesIndexes)
{
if (tsvFilename.empty() and rawInput.empty())
util::myThrow("tsvFilename and rawInput can't be both empty");
if (sentences.empty() and rawInput.empty())
util::myThrow("sentences and rawInput can't be both empty");
createColumns(mcd);
......@@ -193,8 +157,8 @@ BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util
this->rawInput = rawInput;
// sentencesIndexes = index of sentences to keep. Empty vector == keep all sentences.
if (not tsvFilename.empty())
readTSVInput(tsvFilename, sentencesIndexes);
if (not sentences.empty())
readTSVInput(sentences, sentencesIndexes);
if (!has(0,wordIndex,0))
addLines(1);
......
......@@ -117,29 +117,6 @@ inline auto decay(Optimizer &optimizer, double rate) -> void
}
}
int countNbSentences(std::string_view filename)
{
std::FILE * file = std::fopen(filename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", filename));
char lineBuffer[100000];
int nbSent = 0;
while (!std::feof(file))
{
if (lineBuffer != std::fgets(lineBuffer, 100000, file))
break;
if (std::string(lineBuffer).size() < 3)
nbSent++;
}
std::fclose(file);
return nbSent;
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......@@ -188,9 +165,13 @@ int MacaonTrain::main()
try
{
ReadingMachine machine(machinePath.string(), true);
std::vector<std::vector<std::string>> trainTsv, devTsv;
if (!trainTsvFile.empty())
trainTsv = util::readTSV(trainTsvFile);
if (!devTsvFile.empty())
devTsv = util::readTSV(devTsvFile);
int nbSentencesTrain = countNbSentences(trainTsvFile);
ReadingMachine machine(machinePath.string(), true);
std::vector<util::utf8string> trainRawInputs;
if (!trainRawFile.empty())
......@@ -199,23 +180,20 @@ int MacaonTrain::main()
if (lineByLine)
{
if (trainRawInputs.size())
for (int i = 0; i < (int)trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[i], std::vector<int>{i});
for (unsigned int i = 0; i < trainRawInputs.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[i], std::vector<int>{(int)i});
else
for (int i = 0; i < nbSentencesTrain; i++)
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>{i});
for (unsigned int i = 0; i < trainTsv.size(); i++)
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (trainRawInputs.size())
goldConfigs.emplace_back(mcd, trainTsvFile, trainRawInputs[0], std::vector<int>());
goldConfigs.emplace_back(mcd, trainTsv, trainRawInputs[0], std::vector<int>());
else
goldConfigs.emplace_back(mcd, trainTsvFile, util::utf8string(), std::vector<int>());
goldConfigs.emplace_back(mcd, trainTsv, util::utf8string(), std::vector<int>());
}
int nbSentencesDev = countNbSentences(devTsvFile);
std::vector<util::utf8string> devRawInputs;
if (!devRawFile.empty())
devRawInputs = util::readFileAsUtf8(devRawFile, lineByLine);
......@@ -223,18 +201,18 @@ int MacaonTrain::main()
if (lineByLine)
{
if (devRawInputs.size())
for (int i = 0; i < (int)devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[i], std::vector<int>{i});
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (int i = 0; i < nbSentencesDev; i++)
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>{i});
for (unsigned int i = 0; i < devTsv.size(); i++)
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devGoldConfigs.emplace_back(mcd, devTsvFile, devRawInputs[0], std::vector<int>());
devGoldConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
else
devGoldConfigs.emplace_back(mcd, devTsvFile, util::utf8string(), std::vector<int>());
devGoldConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
Trainer trainer(machine, batchSize);
......@@ -322,13 +300,23 @@ int MacaonTrain::main()
if (computeDevScore)
{
std::vector<BaseConfig> devConfigs;
for (int i = 0; i < (int)devRawInputs.size(); i++)
if (devRawInputs.size() == 1)
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>());
if (lineByLine)
{
if (devRawInputs.size())
for (unsigned int i = 0; i < devRawInputs.size(); i++)
devConfigs.emplace_back(mcd, devTsv, devRawInputs[i], std::vector<int>{(int)i});
else
for (unsigned int i = 0; i < devTsv.size(); i++)
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>{(int)i});
}
else
{
if (devRawInputs.size())
devConfigs.emplace_back(mcd, devTsv, devRawInputs[0], std::vector<int>());
else
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInputs[i], std::vector<int>{i});
if (devConfigs.empty())
devConfigs.emplace_back(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, util::utf8string(), std::vector<int>());
devConfigs.emplace_back(mcd, devTsv, util::utf8string(), std::vector<int>());
}
for (auto & devConfig : devConfigs)
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment