// dev.cpp
// Commit history attributes this file to Franck Dary.
#include <cstdio>
#include <exception>
#include <filesystem>
#include <string>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"

int main(int argc, char * argv[])
  if (argc != 8)
    fmt::print(stderr, "needs 7 arguments.\n");
Franck Dary's avatar
Franck Dary committed
  std::string model = argv[1];
  std::string mcdFile = argv[2];
  std::string trainTsvFile = argv[3];
  std::string trainRawFile = "";
  std::string devTsvFile = argv[5];
  std::string devRawFile = "";
  int nbEpoch = std::stoi(argv[7]);
Franck Dary's avatar
Franck Dary committed
  std::filesystem::path modelPath(model);
  auto machinePath = modelPath / "machine.rm";

  ReadingMachine machine(machinePath.string());
  BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
  SubConfig config(goldConfig);

Franck Dary's avatar
Franck Dary committed
  Trainer trainer(machine);
  trainer.createDataset(config);
Franck Dary's avatar
Franck Dary committed
  Decoder decoder(machine);
  BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
Franck Dary's avatar
Franck Dary committed

  for (int i = 0; i < nbEpoch; i++)
Franck Dary's avatar
Franck Dary committed
    float loss = trainer.epoch();
    auto devConfig = devGoldConfig;
Franck Dary's avatar
Franck Dary committed
    decoder.decode(devConfig, 1);
    decoder.evaluate(devConfig, modelPath, devTsvFile);
    fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
Franck Dary's avatar
Franck Dary committed
  return 0;
}