Skip to content
Snippets Groups Projects
Commit 05bf505e authored by Franck Dary's avatar Franck Dary
Browse files

Made macaon_train executable

parent 05acae81
No related branches found
No related tags found
No related merge requests found
...@@ -4,11 +4,16 @@ project(test_torch) ...@@ -4,11 +4,16 @@ project(test_torch)
add_compile_definitions(BOOST_DISABLE_THREADS) add_compile_definitions(BOOST_DISABLE_THREADS)
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
find_package(Boost 1.53.0 REQUIRED COMPONENTS program_options)
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS}) include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
add_library(Torch SHARED IMPORTED) add_library(Torch SHARED IMPORTED)
set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES}) set_target_properties(Torch PROPERTIES IMPORTED_LOCATION ${TORCH_LIBRARIES})
add_library(Boost SHARED IMPORTED)
set_target_properties(Boost PROPERTIES IMPORTED_LOCATION ${Boost_PROGRAM_OPTIONS_LIBRARY_RELEASE})
string(APPEND CMAKE_INSTALL_RPATH ":${TORCH_INSTALL_PREFIX}/lib")
set(CMAKE_VERBOSE_MAKEFILE 0) set(CMAKE_VERBOSE_MAKEFILE 0)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
......
...@@ -2,4 +2,11 @@ FILE(GLOB SOURCES src/*.cpp) ...@@ -2,4 +2,11 @@ FILE(GLOB SOURCES src/*.cpp)
add_library(trainer STATIC ${SOURCES}) add_library(trainer STATIC ${SOURCES})
target_link_libraries(trainer reading_machine) target_link_libraries(trainer reading_machine)
target_link_libraries(trainer torch_modules)
add_executable(macaon_train src/macaon_train.cpp)
target_link_libraries(macaon_train Boost)
target_link_libraries(macaon_train trainer)
target_link_libraries(macaon_train decoder)
target_link_libraries(macaon_train common)
install(TARGETS macaon_train DESTINATION bin)
#include <boost/program_options.hpp>
#include <filesystem>
#include "util.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"
namespace po = boost::program_options;
po::options_description getOptionsDescription()
{
po::options_description desc("Command-Line Arguments ");
po::options_description req("Required");
req.add_options()
("expName", po::value<std::string>()->required(),
"Name of this experiment")
("model", po::value<std::string>()->required(),
"Directory containing the machine file to train")
("trainTSV", po::value<std::string>()->required(),
"TSV file of the training corpus, in CONLLU format");
po::options_description opt("Optional");
opt.add_options()
("trainTXT", po::value<std::string>()->default_value(""),
"Raw text file of the training corpus")
("devTSV", po::value<std::string>()->default_value(""),
"TSV file of the development corpus, in CONLLU format")
("devTXT", po::value<std::string>()->default_value(""),
"Raw text file of the development corpus")
("nbEpochs,n", po::value<int>()->default_value(5),
"Number of training epochs")
("help,h", "Produce this help message");
desc.add(req).add(opt);
return desc;
}
po::variables_map checkOptions(po::options_description & od, int argc, char ** argv)
{
po::variables_map vm;
try {po::store(po::parse_command_line(argc, argv, od), vm);}
catch(std::exception & e) {util::myThrow(e.what());}
if (vm.count("help"))
{
std::stringstream ss;
ss << od;
fmt::print(stderr, "{}\n", ss.str());
exit(0);
}
try {po::notify(vm);}
catch(std::exception& e) {util::myThrow(e.what());}
return vm;
}
int main(int argc, char * argv[])
{
auto od = getOptionsDescription();
auto variables = checkOptions(od, argc, argv);
auto expName = variables["expName"].as<std::string>();
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / "machine.rm";
auto mcdFile = variables["mcd"].as<std::string>();
auto trainTsvFile = variables["trainTSV"].as<std::string>();
auto trainRawFile = variables["trainTXT"].as<std::string>();
auto devTsvFile = variables["devTSV"].as<std::string>();
auto devRawFile = variables["devTXT"].as<std::string>();
auto nbEpoch = variables["nbEpochs"].as<int>();
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
SubConfig config(goldConfig);
Trainer trainer(machine);
trainer.createDataset(config);
Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch();
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1);
decoder.evaluate(devConfig, modelPath, devTsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
}
return 0;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment