// dev.cpp — builds a training dataset from a gold configuration and trains TestNetwork on it.
// Author: Franck Dary
#include <cstdio>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

int main(int argc, char * argv[])
  if (argc != 5)
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);

  while (true)
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

    int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
Franck Dary's avatar
Franck Dary committed

  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-3).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-3).beta1(0.5));
  for (int epoch = 1; epoch <= 30; ++epoch)
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
      loss.backward();
      denseOptimizer.step();
      sparseOptimizer.step();

      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
Franck Dary's avatar
Franck Dary committed
  return 0;
}