// dev.cpp
#include <cstdio>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
// Assumed additional includes for types used directly below (ReadingMachine, Dict).
#include "ReadingMachine.hpp"
#include "Dict.hpp"

//constexpr int batchSize = 50;
//constexpr int nbExamples = 350000;
//constexpr int embeddingSize = 20;
//constexpr int nbClasses = 15;
//constexpr int nbWordsPerDatapoint = 5;
//constexpr int maxNbEmbeddings = 1000000;
//
//struct NetworkImpl : torch::nn::Module
//{
//  torch::nn::Linear linear{nullptr};
//  torch::nn::Embedding wordEmbeddings{nullptr};
//
//  std::vector<torch::Tensor> _sparseParameters;
//  std::vector<torch::Tensor> _denseParameters;
//  NetworkImpl()
//  {
//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
//    auto params = linear->parameters();
//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
//
//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
//    params = wordEmbeddings->parameters();
//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
//  };
//  const std::vector<torch::Tensor> & denseParameters()
//  {
//    return _denseParameters;
//  }
//  const std::vector<torch::Tensor> & sparseParameters()
//  {
//    return _sparseParameters;
//  }
//  torch::Tensor forward(const torch::Tensor & input)
//  {
//    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
//    return torch::softmax(linear(embeddingsOfInput),1);
//  }
//};
//TORCH_MODULE(Network);

//int main(int argc, char * argv[])
//{
//  auto nn = Network();
//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
//
//  for (auto & batch : batches)
//  {
//    sparseOptimizer.zero_grad();
//    denseOptimizer.zero_grad();
//    auto prediction = nn(batch.first);
//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
//    loss.backward();
//    sparseOptimizer.step();
//    denseOptimizer.step();
//  }
//  return 0;
//}

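// Builds a (context, gold class) dataset by replaying a reading machine over a gold corpus,
// then trains TestNetwork on that dataset.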
int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "Needs 4 arguments : machineFile mcdFile tsvFile rawFile.\n");
    exit(1);
  }

  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

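  // Load the reading machine, then build the gold configuration from the corpus and a working sub-configuration of it.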
  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);

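  // Replay the machine on the gold data : at each step extract a feature context and the gold class,
  // apply the best applicable transition, then follow the strategy's movement until it signals the end.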
  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition appliable !");

    auto context = config.extractContext(5,5,dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());

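    // NOTE : goldIndex is hard-coded for now; presumably the real gold class should come from the gold configuration.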
    int goldIndex = 3;
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;

    classes.emplace_back(gold);

    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

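  // Wrap the examples in a torch data loader that serves stacked batches of 100, loading in the main thread.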
  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));

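  // Train for two epochs : dense parameters use Adam and sparse parameters (presumably the embeddings) use
  // SparseAdam, since dense Adam cannot consume the sparse gradients produced by sparse embedding layers.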
  for (int epoch = 1; epoch <= 2; ++epoch)
  {
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();
      loss.backward();
      denseOptimizer.step();
      sparseOptimizer.step();

      if (++currentBatchNumber*batchSize % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }

  return 0;
}