From 92e9fda7f68cf9c743fefcb85357ed6a985eaf25 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 27 Jan 2020 16:51:50 +0100
Subject: [PATCH] SparseAdam moved to torch

---
 dev/src/dev.cpp | 134 +++---------------------------------------------
 1 file changed, 8 insertions(+), 126 deletions(-)

diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 3f995d4..b98111b 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,127 +8,6 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-namespace torch
-{
-namespace optim
-{
-
-class SparseAdam : public Optimizer
-{
-  public:
-
-  template <typename ParameterContainer>
-  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
-      : Optimizer(std::forward<ParameterContainer>(parameters)),
-        options(options_)
-  {
-  }
-
-  void step() override
-  {
-    for (size_t i = 0; i < parameters_.size(); ++i)
-    {
-      Tensor p = parameters_.at(i);
-      if (!p.grad().defined())
-        continue;
-
-      auto& exp_average = buffer_at(exp_average_buffers, i);
-      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
-  
-      buffer_at(step_buffers, i) += 1;
-      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
-      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
-      if (p.grad().is_sparse())
-      {
-        NoGradGuard guard;
-        p.grad() = p.grad().coalesce();
-        auto indices = p.grad().indices().squeeze();
-        auto values = p.grad().values();
-
-        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
-        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average[indices[j].item<long>()] += exp_average_update_values[j];
-        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
-        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
-
-        auto numer = exp_average_update_values.add_(old_exp_average_values);
-        exp_average_sq_update_values.add_(old_exp_average_sq_values);
-        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
-        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
-        auto divided = numer.div(denom);
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          p.data()[indices[j].item<long>()] += -step_size*divided[j];
-      }
-      else
-      {
-        if (options.weight_decay() > 0)
-        {
-          NoGradGuard guard;
-          p.grad() = p.grad() + options.weight_decay() * p;
-        }
-
-        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
-        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
-
-        Tensor denom;
-        if (options.amsgrad())
-        {
-          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
-          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
-          denom = max_exp_average_sq / bias_correction2;
-        }
-        else
-        {
-          denom = exp_average_sq / bias_correction2;
-        }
-
-        const auto step_size = options.learning_rate() / bias_correction1;
-
-        NoGradGuard guard;
-        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
-      }
-    }
-  }
-
-  void save(serialize::OutputArchive& archive) const override
-  {
-    //serialize(*this, archive);
-  }
-
-  void load(serialize::InputArchive& archive) override
-  {
-    //serialize(*this, archive);
-  }
-
-  public :
-
-  AdamOptions options;
-
-  std::vector<int64_t> step_buffers;
-  std::vector<Tensor> exp_average_buffers;
-  std::vector<Tensor> exp_average_sq_buffers;
-  std::vector<Tensor> max_exp_average_sq_buffers;
-
-  private:
-
-  SparseAdam() : options(0) {}
-
-  template <typename Self, typename Archive>
-  static void serialize(Self& self, Archive& archive)
-  {
-    _TORCH_OPTIM_SERIALIZE(step_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
-    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
-  }
-};
-
-} // torch
-} // optim
-
 constexpr int batchSize = 50;
 constexpr int nbExamples = 350000;
 constexpr int embeddingSize = 20;
@@ -143,8 +22,8 @@ struct NetworkImpl : torch::nn::Module
   torch::nn::Embedding wordEmbeddings{nullptr};
   NetworkImpl()
   {
-    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
   };
   torch::Tensor forward(const torch::Tensor & input)
   {
@@ -158,18 +37,21 @@ TORCH_MODULE(Network);
 int main(int argc, char * argv[])
 {
   auto nn = Network();
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
   std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
   for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
     batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
 
   for (auto & batch : batches)
   {
-    optimizer.zero_grad();
+    sparseOptimizer.zero_grad();
+    denseOptimizer.zero_grad();
     auto prediction = nn(batch.first);
     auto loss = torch::nll_loss(torch::log(prediction), batch.second);
     loss.backward();
-    optimizer.step();
+    sparseOptimizer.step();
+    denseOptimizer.step();
   }
   return 0;
 }
-- 
GitLab