Skip to content
Snippets Groups Projects
Commit e00aa24d authored by Franck Dary's avatar Franck Dary
Browse files

mcd file is no longer needed, we can give mcd through program argument, use...

The mcd file is no longer needed: we can give the mcd through a program argument, use the default one, or read it from the conllu file metadata
parent a4b28a3e
No related branches found
No related tags found
No related merge requests found
......@@ -14,14 +14,14 @@ po::options_description MacaonDecode::getOptionsDescription()
("inputTSV", po::value<std::string>(),
"File containing the text to decode, TSV file")
("inputTXT", po::value<std::string>(),
"File containing the text to decode, raw text file")
("mcd", po::value<std::string>()->required(),
"Multi Column Description file that describes the input/output format");
"File containing the text to decode, raw text file");
po::options_description opt("Optional");
opt.add_options()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format")
("beamSize", po::value<int>()->default_value(1),
"Size of the beam during beam search")
("beamThreshold", po::value<float>()->default_value(0.1),
......@@ -72,7 +72,7 @@ int MacaonDecode::main()
auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcdFile = variables["mcd"].as<std::string>();
auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
auto beamSize = variables["beamSize"].as<int>();
......@@ -90,7 +90,7 @@ int MacaonDecode::main()
ReadingMachine machine(machinePath, modelPaths);
Decoder decoder(machine);
BaseConfig config(mcdFile, inputTSV, inputTXT);
BaseConfig config(mcd, inputTSV, inputTXT);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
......
......@@ -19,13 +19,13 @@ class BaseConfig : public Config
private :
void readMCD(std::string_view mcdFilename);
void createColumns(std::string mcd);
void readRawInput(std::string_view rawFilename);
void readTSVInput(std::string_view tsvFilename);
public :
BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename);
BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename);
BaseConfig(const BaseConfig & other);
BaseConfig & operator=(const BaseConfig & other) = default;
......
......@@ -61,6 +61,7 @@ class Config
std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions;
std::shared_ptr<Strategy> strategy;
std::string mcd;
protected :
......
#include "BaseConfig.hpp"
#include "util.hpp"
void BaseConfig::readMCD(std::string_view mcdFilename)
void BaseConfig::createColumns(std::string mcd)
{
if (!colIndex2Name.empty())
util::myThrow("a mcd has already been read for this BaseConfig");
this->mcd = mcd;
std::FILE * file = std::fopen(mcdFilename.data(), "r");
colIndex2Name.clear();
colName2Index.clear();
if (not file)
util::myThrow(fmt::format("Cannot open file '{}'", mcdFilename));
char lineBuffer[1024];
while (std::fscanf(file, "%1023[^\n]\n", lineBuffer) == 1)
auto splited = util::split(mcd, ',');
for (auto & colName : splited)
{
colIndex2Name.emplace_back(lineBuffer);
colName2Index.emplace(lineBuffer, colIndex2Name.size()-1);
colIndex2Name.emplace_back(colName);
colName2Index.emplace(colName, colIndex2Name.size()-1);
}
std::fclose(file);
for (auto & column : extraColumns)
{
if (colName2Index.count(column))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, column));
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcd, column));
colIndex2Name.emplace_back(column);
colName2Index.emplace(column, colIndex2Name.size()-1);
}
......@@ -111,6 +106,12 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
if (line[0] == '#')
{
if (util::doIfNameMatch(std::regex("(?:(?:\\s|\\t)*)#(?:(?:\\s|\\t)*)global.columns(?:(?:\\s|\\t)*)=(?:(?:\\s|\\t)*)(.+)"), line, [this](const auto & sm)
{
createColumns(util::join(",", util::split(util::strip(sm.str(1)), ' ')));
}))
continue;
addLines(1);
get(EOSColName, getNbLines()-1, 0) = EOSSymbol0;
get(isMultiColName, getNbLines()-1, 0) = EOSSymbol0;
......@@ -158,14 +159,12 @@ BaseConfig::BaseConfig(const BaseConfig & other) : Config(other), colIndex2Name(
{
}
BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename)
BaseConfig::BaseConfig(std::string mcd, std::string_view tsvFilename, std::string_view rawFilename)
{
if (tsvFilename.empty() and rawFilename.empty())
util::myThrow("tsvFilename and rawFilenames can't be both empty");
if (mcdFilename.empty())
util::myThrow("mcdFilename can't be empty");
readMCD(mcdFilename);
createColumns(mcd);
if (not rawFilename.empty())
readRawInput(rawFilename);
......@@ -192,7 +191,7 @@ std::size_t BaseConfig::getColIndex(const std::string & colName) const
{
auto it = colName2Index.find(colName);
if (it == colName2Index.end())
util::myThrow(fmt::format("unknown column name '{}'", colName));
util::myThrow(fmt::format("unknown column name '{}', mcd = '{}'", colName, mcd));
return it->second;
}
......
......@@ -19,7 +19,8 @@ Config::Config(const Config & other)
this->state = other.state;
this->history = other.history;
this->stack = other.stack;
this->extraColumns = this->extraColumns;
this->extraColumns = other.extraColumns;
this->mcd = other.mcd;
}
std::size_t Config::getIndexOfLine(int lineIndex) const
......@@ -106,6 +107,8 @@ void Config::print(FILE * dest) const
currentSequenceComments.clear();
};
fmt::print(dest, "# global.columns = {}\n", util::join(" ", util::split(mcd, ',')));
for (unsigned int line = 0; line < getNbLines(); line++)
{
if (isComment(getFirstLineIndex()+line))
......
......@@ -15,6 +15,7 @@ SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : model(model), s
currentWordId = model.currentWordId;
appliableSplitTransitions = model.appliableSplitTransitions;
appliableTransitions = model.appliableTransitions;
mcd = model.mcd;
if (model.strategy.get() != nullptr)
strategy.reset(new Strategy(model.getStrategy()));
update();
......
......@@ -13,8 +13,6 @@ po::options_description MacaonTrain::getOptionsDescription()
req.add_options()
("model", po::value<std::string>()->required(),
"Directory containing the machine file to train")
("mcd", po::value<std::string>()->required(),
"Multi Column Description file that describes the input format")
("trainTSV", po::value<std::string>()->required(),
"TSV file of the training corpus, in CONLLU format");
......@@ -23,6 +21,8 @@ po::options_description MacaonTrain::getOptionsDescription()
("debug,d", "Print debuging infos on stderr")
("silent", "Don't print speed and progress")
("devScore", "Compute score on dev instead of loss (slower)")
("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
"Comma separated column names that describes the input/output format")
("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""),
......@@ -97,7 +97,7 @@ int MacaonTrain::main()
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / "machine.rm";
auto mcdFile = variables["mcd"].as<std::string>();
auto mcd = variables["mcd"].as<std::string>();
auto trainTsvFile = variables["trainTSV"].as<std::string>();
auto trainRawFile = variables["trainTXT"].as<std::string>();
auto devTsvFile = variables["devTSV"].as<std::string>();
......@@ -132,8 +132,8 @@ int MacaonTrain::main()
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
BaseConfig goldConfig(mcd, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
......@@ -230,7 +230,7 @@ int MacaonTrain::main()
std::vector<std::pair<float,std::string>> devScores;
if (computeDevScore)
{
BaseConfig devConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
BaseConfig devConfig(mcd, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment