    #include <cstdio>
    #include <cmath>
    #include <vector>

    #include <torch/torch.h>

    #include "fmt/core.h"
    
    #include "util.hpp"
    
    #include "BaseConfig.hpp"
    #include "SubConfig.hpp"
    
    #include "TransitionSet.hpp"
    
    #include "TestNetwork.hpp"
    
    #include "ConfigDataset.hpp"
    
    namespace torch
    {
    namespace optim
    {

    class SparseAdam : public Optimizer
    {
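      // Adam variant that also handles sparse gradients (as emitted by sparse
      // embedding layers). Per-parameter update, with g the gradient and t the
      // step count:
      //   m = beta1*m + (1-beta1)*g
      //   v = beta2*v + (1-beta2)*g*g
      //   p -= lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + eps)
      // On the sparse path, only the rows present in g are touched.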
      public:
    
      template <typename ParameterContainer>
      explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
          : Optimizer(std::forward<ParameterContainer>(parameters)),
            options(options_)
      {
      }

      void step() override
      {
        for (size_t i = 0; i < parameters_.size(); ++i)
        {
          Tensor p = parameters_.at(i);
          if (!p.grad().defined())
            continue;
    
          auto& exp_average = buffer_at(exp_average_buffers, i);
          auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
      
          buffer_at(step_buffers, i) += 1;
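          // m and v start at zero, so the running averages are biased toward
          // zero early on; the (1 - beta^t) factors below implement the
          // standard Adam bias correction.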
          const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
          const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
          if (p.grad().is_sparse())
          {
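            // Sparse path: the gradient only carries values for the rows used
            // in the forward pass, so the moment buffers and the parameter are
            // updated row by row for just those indices.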
            NoGradGuard guard;
            p.grad() = p.grad().coalesce();
            auto indices = p.grad().indices().squeeze();
            auto values = p.grad().values();
    
            auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
            auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
            for (unsigned int j = 0; j < indices.size(0); j++)
              exp_average[indices[j].item<long>()] += exp_average_update_values[j];
            auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
            auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
            for (unsigned int j = 0; j < indices.size(0); j++)
              exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
    
            auto numer = exp_average_update_values.add_(old_exp_average_values);
            exp_average_sq_update_values.add_(old_exp_average_sq_values);
            auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
            const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
            auto divided = numer.div(denom);
            for (unsigned int j = 0; j < indices.size(0); j++)
              p.data()[indices[j].item<long>()] += -step_size*divided[j];
          }
          else
          {
            if (options.weight_decay() > 0)
            {
              NoGradGuard guard;
              p.grad() = p.grad() + options.weight_decay() * p;
            }
    
            exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
            exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
    
            Tensor denom;
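            // AMSGrad variant: use the running maximum of v instead of v
            // itself, so the effective per-coordinate step size never increases.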
            if (options.amsgrad())
            {
              auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
              max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
              denom = max_exp_average_sq / bias_correction2;
            }
            else
            {
              denom = exp_average_sq / bias_correction2;
            }
    
            const auto step_size = options.learning_rate() / bias_correction1;
    
            NoGradGuard guard;
            p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
          }
        }
      }
    
      void save(serialize::OutputArchive& archive) const override
      {
        //serialize(*this, archive);
      }
    
      void load(serialize::InputArchive& archive) override
      {
        //serialize(*this, archive);
      }
    
      public:
    
      AdamOptions options;
    
      std::vector<int64_t> step_buffers;
      std::vector<Tensor> exp_average_buffers;
      std::vector<Tensor> exp_average_sq_buffers;
      std::vector<Tensor> max_exp_average_sq_buffers;
    
      private:
    
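      // Default constructor is presumably only for deserialization; the
      // placeholder options are overwritten when state is loaded.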
      SparseAdam() : options(0) {}
    
      template <typename Self, typename Archive>
      static void serialize(Self& self, Archive& archive)
      {
        _TORCH_OPTIM_SERIALIZE(step_buffers);
        _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
        _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
        _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
      }
    };

    } // namespace optim
    } // namespace torch
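
    // Usage sketch (hypothetical, not exercised in this file): SparseAdam is
    // meant for modules that emit sparse gradients, e.g. an Embedding built
    // with .sparse(true):
    //
    //   auto net = Network(); // NetworkImpl below, but with .sparse(true)
    //   torch::optim::SparseAdam optimizer(net->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
    //   optimizer.zero_grad();
    //   // forward, loss.backward() ...
    //   optimizer.step();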
    
    constexpr int batchSize = 50;
    constexpr int nbExamples = 350000;
    constexpr int embeddingSize = 20;
    constexpr int nbClasses = 15;
    constexpr int nbWordsPerDatapoint = 5;
    constexpr int maxNbEmbeddings = 1000000;
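
    // With these settings: 350000 examples of 5 word ids each, drawn from a
    // vocabulary of 1000000 embeddings of dimension 20, classified into 15
    // classes in batches of 50 (7000 batches per pass).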
    
    // Observed runtime: ~3m15s for this configuration.
    struct NetworkImpl : torch::nn::Module
    {
      torch::nn::Linear linear{nullptr};
      torch::nn::Embedding wordEmbeddings{nullptr};
      NetworkImpl()
      {
        linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
        wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
      }
      torch::Tensor forward(const torch::Tensor & input)
      {
        // Each datapoint is a list of word indices; the sentence embedding is
        // taken to be the mean of its word embeddings.
        auto embeddingsOfInput = wordEmbeddings(input).mean(1);
        return torch::softmax(linear(embeddingsOfInput), 1);
      }
    };
    TORCH_MODULE(Network);
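    // TORCH_MODULE(Network) defines Network as a module holder: a
    // shared_ptr-like wrapper around NetworkImpl, hence the nn->parameters()
    // syntax in main().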
    
    int main(int argc, char * argv[])
    {
      auto nn = Network();
      torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
      std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
      for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
        batches.emplace_back(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, {batchSize}, at::kLong));
    
      for (auto & batch : batches)
      {
        optimizer.zero_grad();
        auto prediction = nn(batch.first);
        auto loss = torch::nll_loss(torch::log(prediction), batch.second);
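        // nll_loss expects log-probabilities; log(softmax(x)) is equivalent to
        // torch::log_softmax(x, 1), which is the numerically stabler spelling.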
        loss.backward();
        optimizer.step();
      }

      return 0;
    }
    
    
    //int main(int argc, char * argv[])
    //{
    //  if (argc != 5)
    //  {
    //    fmt::print(stderr, "needs 4 arguments.\n");
    //    exit(1);
    //  }
    //
    //  at::init_num_threads();
    //
    //  std::string machineFile = argv[1];
    //  std::string mcdFile = argv[2];
    //  std::string tsvFile = argv[3];
    //  //std::string rawFile = argv[4];
    //  std::string rawFile = "";
    //
    //  ReadingMachine machine(machineFile);
    //
    //  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
    //  SubConfig config(goldConfig);
    //
    //  config.setState(machine.getStrategy().getInitialState());
    //
    //  std::vector<torch::Tensor> contexts;
    //  std::vector<torch::Tensor> classes;
    //
    //  fmt::print("Generating dataset...\n");
    //
    //  Dict dict(Dict::State::Open);
    //
    //  while (true)
    //  {
    //    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    //    if (!transition)
    //      util::myThrow("No transition appliable !");
    //
    //    auto context = config.extractContext(5,5,dict);
    //    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
    //
    //    int goldIndex = 3;
    //    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
    //
    //    classes.emplace_back(gold);
    //
    //    transition->apply(config);
    //    config.addToHistory(transition->getName());
    //
    //    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    //    if (movement == Strategy::endMovement)
    //      break;
    //
    //    config.setState(movement.first);
    //    if (!config.moveWordIndex(movement.second))
    //      util::myThrow("Cannot move word index !");
    //
    //    if (config.needsUpdate())
    //      config.update();
    //  }
    //
    //  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
    //
    //  int nbExamples = *dataset.size();
    //  fmt::print("Done! size={}\n", nbExamples);
    //
    //  int batchSize = 100;
    //  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
    //
    //  TestNetwork nn(machine.getTransitionSet().size(), 5);
    //  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
    //
    //  for (int epoch = 1; epoch <= 1; ++epoch)
    //  {
    //    float totalLoss = 0.0;
    //    torch::Tensor example;
    //    int currentBatchNumber = 0;
    //
    //    for (auto & batch : *dataLoader)
    //    {
    //      optimizer.zero_grad();
    //
    //      auto data = batch.data;
    //      auto labels = batch.target.squeeze();
    //
    //      auto prediction = nn(data);
    //      example = prediction[0];
    //
    //      auto loss = torch::nll_loss(torch::log(prediction), labels);
    //      totalLoss += loss.item<float>();
    //      loss.backward();
    //      optimizer.step();
    //
    //      if (++currentBatchNumber*batchSize % 1000 == 0)
    //      {
    //        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
    //        std::fflush(stdout);
    //      }
    //    }
    //
    //    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
    //  }
    //
    //  return 0;
    //}
    //