-
Franck Dary authored
mcd file is no longer needed: we can give the mcd through a program argument, use the default one, or read it from the conllu file metadata
Franck Dary authored: mcd file is no longer needed: we can give the mcd through a program argument, use the default one, or read it from the conllu file metadata
MacaonDecode.cpp 3.44 KiB
#include "MacaonDecode.hpp"
#include <filesystem>
#include "util.hpp"
#include "Decoder.hpp"
/// Builds the description of every command-line argument of the decoder.
///
/// Arguments are declared in two groups, "Required" and "Optional", which are
/// then appended (in that order) to the returned top-level description.
/// Only "model" is flagged with required() here; the two input options are
/// declared without that flag.
///
/// @return The complete options description, ready to be used for parsing.
po::options_description MacaonDecode::getOptionsDescription()
{
  po::options_description requiredGroup("Required");
  auto addRequired = requiredGroup.add_options();
  addRequired("model", po::value<std::string>()->required(),
    "Directory containing the trained machine used to decode");
  addRequired("inputTSV", po::value<std::string>(),
    "File containing the text to decode, TSV file");
  addRequired("inputTXT", po::value<std::string>(),
    "File containing the text to decode, raw text file");

  po::options_description optionalGroup("Optional");
  auto addOptional = optionalGroup.add_options();
  // Flags without a value: only their presence matters.
  addOptional("debug,d", "Print debuging infos on stderr");
  addOptional("silent", "Don't print speed and progress");
  // Valued options, each with a default so they are always set after parsing.
  addOptional("mcd", po::value<std::string>()->default_value("ID,FORM,LEMMA,UPOS,XPOS,FEATS,HEAD,DEPREL"),
    "Comma separated column names that describes the input/output format");
  addOptional("beamSize", po::value<int>()->default_value(1),
    "Size of the beam during beam search");
  addOptional("beamThreshold", po::value<float>()->default_value(0.1),
    "Minimal probability an action must have to be considered in the beam search");
  addOptional("help,h", "Produce this help message");

  po::options_description allOptions("Command-Line Arguments ");
  allOptions.add(requiredGroup);
  allOptions.add(optionalGroup);

  return allOptions;
}
/// Parses the stored command line against `od` and validates the result.
///
/// - Parsing and notification errors are forwarded to util::myThrow with the
///   exception message.
/// - When "help" is present, the option descriptions are printed on stderr
///   and the process exits with status 0. This happens before notification,
///   so help can be requested without the other arguments.
/// - Exactly one of "inputTSV" / "inputTXT" must be present; otherwise an
///   error followed by the option descriptions is printed on stderr and the
///   process exits with status 1.
///
/// @param od Description of the accepted options.
/// @return The parsed and validated option values.
po::variables_map MacaonDecode::checkOptions(po::options_description & od)
{
  // Renders the option descriptions as text, for the help and error outputs.
  auto renderUsage = [&od]()
  {
    std::stringstream buffer;
    buffer << od;
    return buffer.str();
  };

  po::variables_map parsed;

  try
  {
    po::store(po::parse_command_line(argc, argv, od), parsed);
  }
  catch (const std::exception & failure)
  {
    util::myThrow(failure.what());
  }

  if (parsed.count("help"))
  {
    fmt::print(stderr, "{}\n", renderUsage());
    exit(0);
  }

  try
  {
    po::notify(parsed);
  }
  catch (const std::exception & failure)
  {
    util::myThrow(failure.what());
  }

  const auto nbInputs = parsed.count("inputTSV") + parsed.count("inputTXT");
  if (nbInputs == 1)
    return parsed;

  fmt::print(stderr, "Error : one and only one input format must be specified.\n{}\n", renderUsage());
  exit(1);

  return parsed;
}
int MacaonDecode::main()
{
auto od = getOptionsDescription();
auto variables = checkOptions(od);
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcd = variables["mcd"].as<std::string>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>();
torch::globalContext().setBenchmarkCuDNN(true);
if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
fmt::print(stderr, "Decoding using device : {}\n", NeuralNetworkImpl::device.str());
try
{
ReadingMachine machine(machinePath, modelPaths);
Decoder decoder(machine);
BaseConfig config(mcd, inputTSV, inputTXT);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout);
} catch(std::exception & e) {util::error(e);}
return 0;
}
/// Stores the raw command line; it is read later by checkOptions.
///
/// @param argc Number of entries in `argv`.
/// @param argv Command-line arguments, as received by the program.
MacaonDecode::MacaonDecode(int argc, char ** argv)
  : argc(argc),
    argv(argv)
{}