// dev.cpp
#include <cstdio>
#include "fmt/core.h"
#include <torch/torch.h>
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

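// Hyper-parameters for the synthetic sparse-embedding benchmark below.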
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;
// Measured runtime: 3m15s
struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};
  NetworkImpl()
  {
    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  }
  torch::Tensor forward(const torch::Tensor & input)
  {
    // Input holds a batch of sentences, each a list of word indices;
    // the sentence embedding is the mean of its word embeddings.
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);
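// TORCH_MODULE generates the `Network` handle type, a shared_ptr-like
// wrapper around NetworkImpl (hence the `nn->` accesses below).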

int main(int argc, char * argv[])
{
  auto nn = Network();
  // The embedding table produces sparse gradients and the linear layer dense
  // ones, so each submodule gets its own optimizer over its own parameters.
  torch::optim::SparseAdam sparseOptimizer(nn->wordEmbeddings->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam denseOptimizer(nn->linear->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
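  // Build synthetic batches: random word indices as inputs, random class
  // labels as targets.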
  std::vector<std::pair<torch::Tensor, torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong),
                         torch::randint(nbClasses, {batchSize}, at::kLong));

  for (auto & batch : batches)
  {
    sparseOptimizer.zero_grad();
    denseOptimizer.zero_grad();
    auto prediction = nn(batch.first);
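    // log(softmax(x)) combined with nll_loss is cross-entropy; note that
    // torch::log_softmax would compute the same thing more stably.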
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
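    // Each optimizer steps only the parameters it was constructed with.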
    sparseOptimizer.step();
    denseOptimizer.step();
  }
  return 0;
}

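// Earlier experiment kept for reference: trains a TestNetwork on contexts
// extracted from a ReadingMachine configuration.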
//int main(int argc, char * argv[])
//{
//  if (argc != 5)
//  {
//    fmt::print(stderr, "needs 4 arguments.\n");
//    exit(1);
//  }
//
//  at::init_num_threads();
//
//  std::string machineFile = argv[1];
//  std::string mcdFile = argv[2];
//  std::string tsvFile = argv[3];
//  //std::string rawFile = argv[4];
//  std::string rawFile = "";
//
//  ReadingMachine machine(machineFile);
//
//  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
//  SubConfig config(goldConfig);
//
//  config.setState(machine.getStrategy().getInitialState());
//
//  std::vector<torch::Tensor> contexts;
//  std::vector<torch::Tensor> classes;
//
//  fmt::print("Generating dataset...\n");
//
//  Dict dict(Dict::State::Open);
//
//  while (true)
//  {
//    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
//    if (!transition)
//      util::myThrow("No transition appliable !");
//
//    auto context = config.extractContext(5,5,dict);
//    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
//    int goldIndex = 3;
//    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
//    classes.emplace_back(gold);
//
//    transition->apply(config);
//    config.addToHistory(transition->getName());
//
//    auto movement = machine.getStrategy().getMovement(config, transition->getName());
//    if (movement == Strategy::endMovement)
//      break;
//
//    config.setState(movement.first);
//    if (!config.moveWordIndex(movement.second))
//      util::myThrow("Cannot move word index !");
//
//    if (config.needsUpdate())
//      config.update();
//  }
//
//  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
//  int nbExamples = *dataset.size();
//  fmt::print("Done! size={}\n", nbExamples);
//
//  int batchSize = 100;
//  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
//  TestNetwork nn(machine.getTransitionSet().size(), 5);
//  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
//  for (int epoch = 1; epoch <= 1; ++epoch)
//  {
//    float totalLoss = 0.0;
//    torch::Tensor example;
//    int currentBatchNumber = 0;
//
//    for (auto & batch : *dataLoader)
//    {
//      optimizer.zero_grad();
//
//      auto data = batch.data;
//      auto labels = batch.target.squeeze();
//
//      auto prediction = nn(data);
//      example = prediction[0];
//
//      auto loss = torch::nll_loss(torch::log(prediction), labels);
//      totalLoss += loss.item<float>();
//      loss.backward();
//      optimizer.step();
//
//      if (++currentBatchNumber*batchSize % 1000 == 0)
//      {
//        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
//        std::fflush(stdout);
//      }
//    }
//
//    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
//  }
//
//  return 0;
//}
//