// dev.cpp — development/training driver (committed by Franck Dary).
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);
  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
Franck Dary's avatar
Franck Dary committed
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());
Franck Dary's avatar
Franck Dary committed
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;
Franck Dary's avatar
Franck Dary committed
  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

Franck Dary's avatar
Franck Dary committed
    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
Franck Dary's avatar
Franck Dary committed
    int goldIndex = 3;
    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
Franck Dary's avatar
Franck Dary committed
    classes.emplace_back(gold);
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

Franck Dary's avatar
Franck Dary committed
  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);
Franck Dary's avatar
Franck Dary committed
  int batchSize = 100;
Franck Dary's avatar
Franck Dary committed
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
Franck Dary's avatar
Franck Dary committed
  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
Franck Dary's avatar
Franck Dary committed
  for (int epoch = 1; epoch <= 1; ++epoch)
Franck Dary's avatar
Franck Dary committed
    float totalLoss = 0.0;
    torch::Tensor example;
Franck Dary's avatar
Franck Dary committed
    int currentBatchNumber = 0;
Franck Dary's avatar
Franck Dary committed

    for (auto & batch : *dataLoader)
    {
      optimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      loss.backward();
      optimizer.step();
Franck Dary's avatar
Franck Dary committed

      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
        std::fflush(stdout);
      }
Franck Dary's avatar
Franck Dary committed
    }

    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
Franck Dary's avatar
Franck Dary committed
  return 0;
}