From 05bf505ee71f9050487557024229a8aeeaa57535 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Fri, 7 Feb 2020 15:57:18 +0100 Subject: [PATCH] Made macaon_train executable --- CMakeLists.txt | 5 ++ trainer/CMakeLists.txt | 7 +++ trainer/src/macaon_train.cpp | 97 ++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 trainer/src/macaon_train.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 302d7f3..6fed182 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,11 +4,16 @@ project(test_torch) add_compile_definitions(BOOST_DISABLE_THREADS) find_package(Torch REQUIRED) +find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options) include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) add_library(Torch SHARED IMPORTED) set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES}) +add_library(Boost SHARED IMPORTED) +set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE}) + +string(APPEND CMAKE_INSTALL_RPATH ":${TORCH_INSTALL_PREFIX}/lib") set(CMAKE_VERBOSE_MAKEFILE 0) set(CMAKE_CXX_STANDARD 17) diff --git a/trainer/CMakeLists.txt b/trainer/CMakeLists.txt index b673afa..649eae0 100644 --- a/trainer/CMakeLists.txt +++ b/trainer/CMakeLists.txt @@ -2,4 +2,11 @@ FILE(GLOB SOURCES src/*.cpp) add_library(trainer STATIC ${SOURCES}) target_link_libraries(trainer reading_machine) +target_link_libraries(trainer torch_modules) +add_executable(macaon_train src/macaon_train.cpp) +target_link_libraries(macaon_train Boost) +target_link_libraries(macaon_train trainer) +target_link_libraries(macaon_train decoder) +target_link_libraries(macaon_train common) +install(TARGETS macaon_train DESTINATION bin) diff --git a/trainer/src/macaon_train.cpp b/trainer/src/macaon_train.cpp new file mode 100644 index 0000000..6b817a5 --- /dev/null +++ b/trainer/src/macaon_train.cpp @@ -0,0 +1,97 @@ +#include <boost/program_options.hpp> +#include <filesystem> +#include "util.hpp" +#include "Trainer.hpp" +#include "Decoder.hpp" + +namespace po = boost::program_options; + +po::options_description getOptionsDescription() +{ + po::options_description desc("Command-Line Arguments "); + + po::options_description req("Required"); + req.add_options() + ("expName", po::value<std::string>()->required(), + "Name of this experiment") + ("model", po::value<std::string>()->required(), + "Directory containing the machine file to train") + ("trainTSV", po::value<std::string>()->required(), + "TSV file of the training corpus, in CONLLU format"); + + po::options_description opt("Optional"); + opt.add_options() + ("trainTXT", po::value<std::string>()->default_value(""), + "Raw text file of the training corpus") + ("devTSV", po::value<std::string>()->default_value(""), + "TSV file of the development corpus, in CONLLU format") + ("devTXT", po::value<std::string>()->default_value(""), + "Raw text file of the development corpus") + ("nbEpochs,n", po::value<int>()->default_value(5), + "Number of training epochs") + ("help,h", "Produce this help message"); + + desc.add(req).add(opt); + + return desc; +} + +po::variables_map checkOptions(po::options_description & od, int argc, char ** argv) +{ + po::variables_map vm; + + try {po::store(po::parse_command_line(argc, argv, od), vm);} + catch(std::exception & e) {util::myThrow(e.what());} + + if (vm.count("help")) + { + std::stringstream ss; + ss << od; + fmt::print(stderr, "{}\n", ss.str()); + exit(0); + } + + try {po::notify(vm);} + catch(std::exception& e) {util::myThrow(e.what());} + + return vm; +} + +int main(int argc, char * argv[]) +{ + auto od = getOptionsDescription(); + auto variables = checkOptions(od, argc, argv); + + auto expName = variables["expName"].as<std::string>(); + std::filesystem::path modelPath(variables["model"].as<std::string>()); + auto machinePath = modelPath / "machine.rm"; + auto mcdFile = variables["mcd"].as<std::string>(); + auto trainTsvFile = variables["trainTSV"].as<std::string>(); + auto trainRawFile = variables["trainTXT"].as<std::string>(); + auto devTsvFile = variables["devTSV"].as<std::string>(); + auto devRawFile = variables["devTXT"].as<std::string>(); + auto nbEpoch = variables["nbEpochs"].as<int>(); + + ReadingMachine machine(machinePath.string()); + + BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile); + SubConfig config(goldConfig); + + Trainer trainer(machine); + trainer.createDataset(config); + + Decoder decoder(machine); + BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile); + + for (int i = 0; i < nbEpoch; i++) + { + float loss = trainer.epoch(); + auto devConfig = devGoldConfig; + decoder.decode(devConfig, 1); + decoder.evaluate(devConfig, modelPath, devTsvFile); + fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS")); + } + + return 0; +} + -- GitLab