// dev.cpp - development driver: generates a (context, class) dataset from a gold configuration and trains a TestNetwork on it.
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);
  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
Franck Dary's avatar
Franck Dary committed
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());
Franck Dary's avatar
Franck Dary committed
  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;
Franck Dary's avatar
Franck Dary committed
  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

Franck Dary's avatar
Franck Dary committed
    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
Franck Dary's avatar
Franck Dary committed
    int goldIndex = 3;
    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
Franck Dary's avatar
Franck Dary committed
    classes.emplace_back(gold);
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
Franck Dary's avatar
Franck Dary committed

Franck Dary's avatar
Franck Dary committed
  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  fmt::print("Done! size={}\n", *dataset.size());
Franck Dary's avatar
Franck Dary committed
  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize));
Franck Dary's avatar
Franck Dary committed
  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
Franck Dary's avatar
Franck Dary committed
  for (int epoch = 1; epoch <= 5; ++epoch)
Franck Dary's avatar
Franck Dary committed
    float totalLoss = 0.0;
    torch::Tensor example;

    for (auto & batch : *dataLoader)
    {
      optimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      loss.backward();
      optimizer.step();
    }

    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
    std::cout << example << std::endl;
Franck Dary's avatar
Franck Dary committed
  return 0;
}