From 92e9fda7f68cf9c743fefcb85357ed6a985eaf25 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Mon, 27 Jan 2020 16:51:50 +0100
Subject: [PATCH] SparseAdam moved to torch

---
 dev/src/dev.cpp | 134 +++---
 1 file changed, 8 insertions(+), 126 deletions(-)

diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 3f995d4..b98111b 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,127 +8,6 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-namespace torch
-{
-namespace optim
-{
-
-class SparseAdam : public Optimizer
-{
-  public:
-
-  template <typename ParameterContainer>
-  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
-    : Optimizer(std::forward<ParameterContainer>(parameters)),
-      options(options_)
-  {
-  }
-
-  void step() override
-  {
-    for (size_t i = 0; i < parameters_.size(); ++i)
-    {
-      Tensor p = parameters_.at(i);
-      if (!p.grad().defined())
-        continue;
-
-      auto& exp_average = buffer_at(exp_average_buffers, i);
-      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
-
-      buffer_at(step_buffers, i) += 1;
-      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
-      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
-      if (p.grad().is_sparse())
-      {
-        NoGradGuard guard;
-        p.grad() = p.grad().coalesce();
-        auto indices = p.grad().indices().squeeze();
-        auto values = p.grad().values();
-
-        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
-        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average[indices[j].item<long>()] += exp_average_update_values[j];
-        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
-        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
-
-        auto numer = exp_average_update_values.add_(old_exp_average_values);
-        exp_average_sq_update_values.add_(old_exp_average_sq_values);
-        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
-        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
-        auto divided = numer.div(denom);
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          p.data()[indices[j].item<long>()] += -step_size*divided[j];
-      }
-      else
-      {
-        if (options.weight_decay() > 0)
-        {
-          NoGradGuard guard;
-          p.grad() = p.grad() + options.weight_decay() * p;
-        }
-
-        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
-        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
-
-        Tensor denom;
-        if (options.amsgrad())
-        {
-          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
-          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
-          denom = max_exp_average_sq / bias_correction2;
-        }
-        else
-        {
-          denom = exp_average_sq / bias_correction2;
-        }
-
-        const auto step_size = options.learning_rate() / bias_correction1;
-
-        NoGradGuard guard;
-        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
-      }
-    }
-  }
-
-  void save(serialize::OutputArchive& archive) const override
-  {
-    //serialize(*this, archive);
-  }
-
-  void load(serialize::InputArchive& archive) override
-  {
-    //serialize(*this, archive);
-  }
-
-  public :
-
-  AdamOptions options;
-
-  std::vector<int64_t> step_buffers;
-  std::vector<Tensor> exp_average_buffers;
-  std::vector<Tensor> exp_average_sq_buffers;
-  std::vector<Tensor> max_exp_average_sq_buffers;
-
-  private:
-
-  SparseAdam() : options(0) {}
-
-  template <typename Self, typename Archive>
-  static void serialize(Self& self, Archive& archive)
-  {
-    _TORCH_OPTIM_SERIALIZE(step_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
-    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
-  }
-};
-
-} // torch
-} // optim
-
 constexpr int batchSize = 50;
 constexpr int nbExamples = 350000;
 constexpr int embeddingSize = 20;
@@ -143,8 +22,8 @@ struct NetworkImpl : torch::nn::Module
   torch::nn::Embedding wordEmbeddings{nullptr};
   NetworkImpl()
   {
-    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
   };
   torch::Tensor forward(const torch::Tensor & input)
   {
@@ -158,18 +37,21 @@ TORCH_MODULE(Network);
 int main(int argc, char * argv[])
 {
   auto nn = Network();
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
   std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
   for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
     batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
 
   for (auto & batch : batches)
   {
-    optimizer.zero_grad();
+    sparseOptimizer.zero_grad();
+    denseOptimizer.zero_grad();
     auto prediction = nn(batch.first);
     auto loss = torch::nll_loss(torch::log(prediction), batch.second);
     loss.backward();
-    optimizer.step();
+    sparseOptimizer.step();
+    denseOptimizer.step();
   }
   return 0;
 }
-- 
GitLab
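
Note on the deleted optimizer: its step() is the standard bias-corrected Adam
update, restricted to the rows that actually occur in a sparse gradient.
Reading the sparse branch, with g_t the gradient values at the coalesced
indices and t the per-parameter step counter, each touched embedding row is
updated as

  m_t      = \beta_1 m_{t-1} + (1 - \beta_1) g_t
  v_t      = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
  \theta_t = \theta_{t-1} - \eta \, \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}

Rows absent from the batch keep their moments and weights untouched, which is
what makes a sparse Adam cheap on a large embedding table: only the rows
referenced by the batch pay for an update.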
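
Note on the new main(): both optimizers are constructed over the full
nn->parameters(), so the forked torch::optim::SparseAdam (SparseAdam is not
part of stock LibTorch at this date) is presumably expected to ignore dense
gradients, and Adam to ignore sparse ones. Below is a minimal sketch of the
more explicit alternative, splitting the parameter lists per module; the
split is an assumption for illustration, not what this commit does:

int main(int argc, char * argv[])
{
  auto nn = Network();

  // wordEmbeddings was built with .sparse(true), so its gradient comes back
  // as a sparse tensor; register it with the sparse optimizer only.
  // (Assumed split, not the commit's behavior.)
  torch::optim::SparseAdam sparseOptimizer(nn->wordEmbeddings->parameters(),
    torch::optim::SparseAdamOptions(2e-4).beta1(0.5));

  // The linear layer produces ordinary dense gradients; plain Adam handles it.
  torch::optim::Adam denseOptimizer(nn->linear->parameters(),
    torch::optim::AdamOptions(2e-4).beta1(0.5));

  // Training loop exactly as in the patch: zero_grad() on both optimizers,
  // one backward() pass, then step() on both.
  return 0;
}

With this split every parameter is stepped exactly once per batch, at the
cost of keeping the two parameter lists in sync with the model by hand.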