// dev.cpp (author: Franck Dary)
#include <cstdio>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "needs 4 arguments.\n");
    exit(1);
  }

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);
  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
Franck Dary's avatar
Franck Dary committed
  SubConfig config(goldConfig);
  config.setState(machine.getStrategy().getInitialState());


  TestNetwork nn(machine.getTransitionSet().size());
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  optimizer.zero_grad();

  std::vector<torch::Tensor> predictionsBatch;
  std::vector<torch::Tensor> referencesBatch;
  std::vector<std::unique_ptr<Config>> configs;
  std::vector<std::size_t> classes;

  fmt::print("Generating dataset...");

  Dict dict(Dict::State::Open);
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    //here train
    int goldIndex = 3;
    auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
    gold[goldIndex] = 1;
//    referencesBatch.emplace_back(gold);
//    predictionsBatch.emplace_back(nn(config));

//    auto loss = torch::nll_loss(prediction, gold);
//    loss.backward();
//    optimizer.step();
    configs.emplace_back(std::unique_ptr<Config>(new SubConfig(config)));
    classes.emplace_back(goldIndex);
//    if (config.getWordIndex() >= 500)
//      exit(1);
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
Franck Dary's avatar
Franck Dary committed

  auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());

  fmt::print("Done!\n");

  auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);

  for (auto & batch : *dataLoader)
  {
    auto data = batch.data;
    auto labels = batch.target.squeeze();
  }

Franck Dary's avatar
Franck Dary committed
  return 0;
}