Commit 92e9fda7 authored by Franck Dary

SparseAdam moved to torch

parent d4ec0a24
@@ -8,127 +8,6 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
-namespace torch
-{
-namespace optim
-{
-class SparseAdam : public Optimizer
-{
-  public:
-  template <typename ParameterContainer>
-  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
-    : Optimizer(std::forward<ParameterContainer>(parameters)),
-      options(options_)
-  {
-  }
-  void step() override
-  {
-    for (size_t i = 0; i < parameters_.size(); ++i)
-    {
-      Tensor p = parameters_.at(i);
-      if (!p.grad().defined())
-        continue;
-      auto& exp_average = buffer_at(exp_average_buffers, i);
-      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
-      buffer_at(step_buffers, i) += 1;
-      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
-      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
-      if (p.grad().is_sparse())
-      {
-        NoGradGuard guard;
-        p.grad() = p.grad().coalesce();
-        auto indices = p.grad().indices().squeeze();
-        auto values = p.grad().values();
-        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
-        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average[indices[j].item<long>()] += exp_average_update_values[j];
-        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
-        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
-        auto numer = exp_average_update_values.add_(old_exp_average_values);
-        exp_average_sq_update_values.add_(old_exp_average_sq_values);
-        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
-        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
-        auto divided = numer.div(denom);
-        for (unsigned int j = 0; j < indices.size(0); j++)
-          p.data()[indices[j].item<long>()] += -step_size * divided[j];
-      }
-      else
-      {
-        if (options.weight_decay() > 0)
-        {
-          NoGradGuard guard;
-          p.grad() = p.grad() + options.weight_decay() * p;
-        }
-        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
-        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
-        Tensor denom;
-        if (options.amsgrad())
-        {
-          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
-          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
-          denom = max_exp_average_sq / bias_correction2;
-        }
-        else
-        {
-          denom = exp_average_sq / bias_correction2;
-        }
-        const auto step_size = options.learning_rate() / bias_correction1;
-        NoGradGuard guard;
-        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
-      }
-    }
-  }
-  void save(serialize::OutputArchive& archive) const override
-  {
-    //serialize(*this, archive);
-  }
-  void load(serialize::InputArchive& archive) override
-  {
-    //serialize(*this, archive);
-  }
-  public:
-  AdamOptions options;
-  std::vector<int64_t> step_buffers;
-  std::vector<Tensor> exp_average_buffers;
-  std::vector<Tensor> exp_average_sq_buffers;
-  std::vector<Tensor> max_exp_average_sq_buffers;
-  private:
-  SparseAdam() : options(0) {}
-  template <typename Self, typename Archive>
-  static void serialize(Self& self, Archive& archive)
-  {
-    _TORCH_OPTIM_SERIALIZE(step_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
-    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
-    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
-  }
-};
-} // namespace optim
-} // namespace torch
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
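
For reference, the removed step() above implements the usual bias-corrected Adam update, with the sparse branch touching only the embedding rows indexed by the non-zero entries of the sparse gradient (notation mine; lr, beta1, beta2 and eps correspond to the AdamOptions fields). Per touched row it computes:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \cdot \frac{m_t}{\sqrt{v_t} + \epsilon}

The dense branch applies the same update, except that the second bias correction is folded into the denominator as v_t / (1 - \beta_2^t) before \epsilon is added.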
@@ -143,8 +22,8 @@ struct NetworkImpl : torch::nn::Module
  torch::nn::Embedding wordEmbeddings{nullptr};
  NetworkImpl()
  {
-    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  };
  torch::Tensor forward(const torch::Tensor & input)
  {
@@ -158,18 +37,21 @@ TORCH_MODULE(Network);
int main(int argc, char * argv[])
{
  auto nn = Network();
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
  for (auto & batch : batches)
  {
-    optimizer.zero_grad();
+    sparseOptimizer.zero_grad();
+    denseOptimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
-    optimizer.step();
+    sparseOptimizer.step();
+    denseOptimizer.step();
  }
  return 0;
}
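
In this test both optimizers are constructed over the full nn->parameters(), so each one also receives parameters it is not meant to update. If one wanted to hand each optimizer only its own parameters, a minimal sketch could key on the module names registered above ("sparse_word_embeddings" vs "dense_linear") and replace the two constructor lines in main(); the names entry, sparseParameters and denseParameters are illustrative, and the torch::optim::SparseAdam interface is assumed to be the one used in this diff:

  // Split parameters by the registered module names: the sparse embedding
  // table goes to SparseAdam, everything else to the dense Adam optimizer.
  std::vector<torch::Tensor> sparseParameters, denseParameters;
  for (auto & entry : nn->named_parameters())
  {
    if (entry.key().find("sparse") != std::string::npos)
      sparseParameters.emplace_back(entry.value());
    else
      denseParameters.emplace_back(entry.value());
  }
  torch::optim::SparseAdam sparseOptimizer(sparseParameters, torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam denseOptimizer(denseParameters, torch::optim::AdamOptions(2e-4).beta1(0.5));

The rest of the training loop (zero_grad, backward, step on both optimizers) stays as in the diff.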