Skip to content
Snippets Groups Projects
Commit f4f14edf authored by Franck Dary's avatar Franck Dary
Browse files

Added option lineByLine for macaon decode, which treat each line of the raw...

Added option lineByLine for macaon decode, which treat each line of the raw input as a different text
parent 363d5cc9
Branches
No related tags found
No related merge requests found
...@@ -106,6 +106,8 @@ utf8string upper(const utf8string & s); ...@@ -106,6 +106,8 @@ utf8string upper(const utf8string & s);
void upper(utf8char & c); void upper(utf8char & c);
std::vector<utf8string> readFileAsUtf8(std::string_view filename, bool lineByLine);
template <typename T> template <typename T>
std::string join(const std::string & delim, const std::vector<T> elems) std::string join(const std::string & delim, const std::vector<T> elems)
{ {
......
...@@ -318,3 +318,50 @@ void util::upper(utf8char & c) ...@@ -318,3 +318,50 @@ void util::upper(utf8char & c)
c = it->second; c = it->second;
} }
std::vector<util::utf8string> util::readFileAsUtf8(std::string_view filename, bool lineByLine)
{
std::vector<utf8string> res;
std::FILE * file = std::fopen(filename.data(), "r");
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", filename));
std::string lineTemp;
if (!lineByLine)
{
while (not std::feof(file))
lineTemp.push_back(std::fgetc(file));
auto line = util::splitAsUtf8(lineTemp);
line.replace(util::utf8char("\n"), util::utf8char(" "));
line.replace(util::utf8char("\t"), util::utf8char(" "));
res.emplace_back(line);
}
else
{
while (not std::feof(file))
{
lineTemp.clear();
while (not std::feof(file))
{
lineTemp.push_back(std::fgetc(file));
if (lineTemp.back() == '\n')
break;
}
auto line = util::splitAsUtf8(lineTemp);
line.replace(util::utf8char("\n"), util::utf8char(" "));
line.replace(util::utf8char("\t"), util::utf8char(" "));
if (!line.empty())
res.emplace_back(line);
}
}
std::fclose(file);
return res;
}
...@@ -22,6 +22,7 @@ po::options_description MacaonDecode::getOptionsDescription() ...@@ -22,6 +22,7 @@ po::options_description MacaonDecode::getOptionsDescription()
("debug,d", "Print debuging infos on stderr") ("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress") ("silent", "Don't print speed and progress")
("reloadEmbeddings", "Reload pretrained embeddings") ("reloadEmbeddings", "Reload pretrained embeddings")
("lineByLine", "Treat the TXT input as being one different text per line.")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"), ("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format") "Comma separated column names that describes the input/output format")
("beamSize", po::value<int>()->default_value(1), ("beamSize", po::value<int>()->default_value(1),
...@@ -78,6 +79,7 @@ int MacaonDecode::main() ...@@ -78,6 +79,7 @@ int MacaonDecode::main()
bool debug = variables.count("debug") == 0 ? false : true; bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false; bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true; bool reloadPretrained = variables.count("reloadEmbeddings") == 0 ? false : true;
bool lineByLine = variables.count("lineByLine") == 0 ? false : true;
auto beamSize = variables["beamSize"].as<int>(); auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>(); auto beamThreshold = variables["beamThreshold"].as<float>();
...@@ -94,11 +96,22 @@ int MacaonDecode::main() ...@@ -94,11 +96,22 @@ int MacaonDecode::main()
ReadingMachine machine(machinePath, false); ReadingMachine machine(machinePath, false);
Decoder decoder(machine); Decoder decoder(machine);
BaseConfig config(mcd, inputTSV, inputTXT); if (inputTXT.empty())
{
BaseConfig config(mcd, inputTSV, util::utf8string());
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement); decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, true);
config.print(stdout); }
else
{
auto inputs = util::readFileAsUtf8(inputTXT, lineByLine);
for (unsigned int i = 0; i < inputs.size(); i++)
{
BaseConfig config(mcd, inputTSV, inputs[i]);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout, i == 0);
}
}
} catch(std::exception & e) {util::error(e);} } catch(std::exception & e) {util::error(e);}
return 0; return 0;
......
...@@ -25,7 +25,7 @@ class BaseConfig : public Config ...@@ -25,7 +25,7 @@ class BaseConfig : public Config
public : public :
BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename); BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawFilename);
BaseConfig(const BaseConfig & other); BaseConfig(const BaseConfig & other);
BaseConfig & operator=(const BaseConfig & other) = default; BaseConfig & operator=(const BaseConfig & other) = default;
......
...@@ -107,7 +107,7 @@ class Config ...@@ -107,7 +107,7 @@ class Config
public : public :
void print(FILE * dest) const; void print(FILE * dest, bool printHeader = true) const;
void printForDebug(FILE * dest) const; void printForDebug(FILE * dest) const;
bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const; bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const;
String & get(const std::string & colName, int lineIndex, int hypothesisIndex); String & get(const std::string & colName, int lineIndex, int hypothesisIndex);
......
...@@ -161,15 +161,15 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name( ...@@ -161,15 +161,15 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
{ {
} }
BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename) BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, const util::utf8string & rawInput)
{ {
if (tsvFilename.empty() and rawFilename.empty()) if (tsvFilename.empty() and rawInput.empty())
util::myThrow("tsvFilename and rawFilenames can't be both empty"); util::myThrow("tsvFilename and rawInput can't be both empty");
createColumns(mcd); createColumns(mcd);
if (not rawFilename.empty()) if (not rawInput.empty())
readRawInput(rawFilename); this->rawInput = rawInput;
if (not tsvFilename.empty()) if (not tsvFilename.empty())
readTSVInput(tsvFilename); readTSVInput(tsvFilename);
......
...@@ -78,7 +78,7 @@ std::size_t Config::getNbLines() const ...@@ -78,7 +78,7 @@ std::size_t Config::getNbLines() const
return lines.size() / getIndexOfCol(getNbColumns()); return lines.size() / getIndexOfCol(getNbColumns());
} }
void Config::print(FILE * dest) const void Config::print(FILE * dest, bool printHeader) const
{ {
std::vector<std::string> currentSequence; std::vector<std::string> currentSequence;
std::vector<std::string> currentSequenceComments; std::vector<std::string> currentSequenceComments;
...@@ -100,6 +100,7 @@ void Config::print(FILE * dest) const ...@@ -100,6 +100,7 @@ void Config::print(FILE * dest) const
currentSequenceComments.clear(); currentSequenceComments.clear();
}; };
if (printHeader)
fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ','))); fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ',')));
for (unsigned int line = 0; line < getNbLines(); line++) for (unsigned int line = 0; line < getNbLines(); line++)
......
...@@ -165,8 +165,20 @@ int MacaonTrain::main() ...@@ -165,8 +165,20 @@ int MacaonTrain::main()
ReadingMachine machine(machinePath.string(), true); ReadingMachine machine(machinePath.string(), true);
BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile); util::utf8string trainRawInput;
BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); if (!trainRawFile.empty())
{
auto input = util::readFileAsUtf8(trainRawFile, false);
trainRawInput = input[0];
}
BaseConfig goldConfig(mcd, trainTsvFile, trainRawInput);
util::utf8string devRawInput;
if (!devRawFile.empty())
{
auto input = util::readFileAsUtf8(devRawFile, false);
devRawInput = input[0];
}
BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput);
Trainer trainer(machine, batchSize); Trainer trainer(machine, batchSize);
Decoder decoder(machine); Decoder decoder(machine);
...@@ -251,7 +263,7 @@ int MacaonTrain::main() ...@@ -251,7 +263,7 @@ int MacaonTrain::main()
std::vector<std::pair<float,std::string>> devScores; std::vector<std::pair<float,std::string>> devScores;
if (computeDevScore) if (computeDevScore)
{ {
BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawInput);
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement); decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
decoder.evaluate(devConfig, modelPath, devTsvFile, machine.getPredicted()); decoder.evaluate(devConfig, modelPath, devTsvFile, machine.getPredicted());
devScores = decoder.getF1Scores(machine.getPredicted()); devScores = decoder.getF1Scores(machine.getPredicted());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment