Commit 92e9fda7 authored by Franck Dary

SparseAdam moved to torch

parent d4ec0a24
@@ -8,127 +8,6 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
namespace torch
{
namespace optim
{

// Local SparseAdam implementation, removed from this file by this commit in
// favour of the version now provided by torch: an Adam variant that also
// handles the sparse gradients produced by embeddings declared with sparse(true).
class SparseAdam : public Optimizer
{
  public:

  template <typename ParameterContainer>
  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
    : Optimizer(std::forward<ParameterContainer>(parameters)),
      options(options_)
  {
  }

  void step() override
  {
    for (size_t i = 0; i < parameters_.size(); ++i)
    {
      Tensor p = parameters_.at(i);
      if (!p.grad().defined())
        continue;

      auto& exp_average = buffer_at(exp_average_buffers, i);
      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);

      buffer_at(step_buffers, i) += 1;
      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));

      if (p.grad().is_sparse())
      {
        // Sparse gradient: update the Adam moments and the parameter only at
        // the rows referenced by the gradient indices.
        NoGradGuard guard;
        p.grad() = p.grad().coalesce();
        auto indices = p.grad().indices().squeeze();
        auto values = p.grad().values();

        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average[indices[j].item<long>()] += exp_average_update_values[j];

        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];

        auto numer = exp_average_update_values.add_(old_exp_average_values);
        exp_average_sq_update_values.add_(old_exp_average_sq_values);
        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
        auto divided = numer.div(denom);
        for (unsigned int j = 0; j < indices.size(0); j++)
          p.data()[indices[j].item<long>()] += -step_size*divided[j];
      }
      else
      {
        // Dense gradient: standard Adam update.
        if (options.weight_decay() > 0)
        {
          NoGradGuard guard;
          p.grad() = p.grad() + options.weight_decay() * p;
        }

        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());

        Tensor denom;
        if (options.amsgrad())
        {
          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
          denom = max_exp_average_sq / bias_correction2;
        }
        else
        {
          denom = exp_average_sq / bias_correction2;
        }

        const auto step_size = options.learning_rate() / bias_correction1;
        NoGradGuard guard;
        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
      }
    }
  }

  void save(serialize::OutputArchive& archive) const override
  {
    //serialize(*this, archive);
  }

  void load(serialize::InputArchive& archive) override
  {
    //serialize(*this, archive);
  }

  public :

  AdamOptions options;
  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

  private:

  SparseAdam() : options(0) {}

  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive)
  {
    _TORCH_OPTIM_SERIALIZE(step_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};

} // optim
} // torch
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
@@ -143,8 +22,8 @@ struct NetworkImpl : torch::nn::Module
  torch::nn::Embedding wordEmbeddings{nullptr};
  NetworkImpl()
  {
-   linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
+   linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+   wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  };
  torch::Tensor forward(const torch::Tensor & input)
  {
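
The hunk above renames the two submodules and, more importantly, declares the embedding table with sparse(true). A minimal standalone sketch of what that flag implies, with sizes that are illustrative and not taken from this repository: the backward pass then yields a sparse gradient carrying only the rows that were looked up, which the stock Adam optimizer does not accept and which a sparse-aware optimizer such as the SparseAdam used below is meant to consume.

#include <torch/torch.h>
#include <iostream>

int main()
{
  // Illustrative sizes: 1000 embeddings of dimension 20, declared sparse.
  auto emb = torch::nn::Embedding(torch::nn::EmbeddingOptions(1000, 20).sparse(true));
  auto input = torch::randint(1000, {4, 6}, at::kLong);
  emb(input).sum().backward();
  // The gradient is a sparse tensor holding entries only for the indices
  // that appeared in the batch.
  std::cout << emb->weight.grad().is_sparse() << std::endl; // prints 1
  return 0;
}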
@@ -158,18 +37,21 @@ TORCH_MODULE(Network);
int main(int argc, char * argv[])
{
  auto nn = Network();
- torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+ torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+ torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
  for (auto & batch : batches)
  {
-   optimizer.zero_grad();
+   sparseOptimizer.zero_grad();
+   denseOptimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
-   optimizer.step();
+   sparseOptimizer.step();
+   denseOptimizer.step();
  }
  return 0;
}
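
In the new main() above, both sparseOptimizer and denseOptimizer are constructed over nn->parameters(). If one wanted each optimizer to update only its own subset of parameters, a split along these lines could be used. This is only a sketch: the helper name is illustrative, the "sparse_" prefix relies on the register_module names used in this commit, and the commented usage assumes the torch build targeted here, which provides torch::optim::SparseAdam and SparseAdamOptions as shown in the diff.

#include <torch/torch.h>
#include <string>
#include <utility>
#include <vector>

// Illustrative helper, not part of this commit: collect the parameters whose
// registered name starts with "sparse_" (the sparse embedding above) in one
// group and every other parameter in a second group.
std::pair<std::vector<torch::Tensor>, std::vector<torch::Tensor>> splitParameters(torch::nn::Module & module)
{
  std::vector<torch::Tensor> sparse, dense;
  for (auto & entry : module.named_parameters())
  {
    if (entry.key().find("sparse_") == 0)
      sparse.push_back(entry.value());
    else
      dense.push_back(entry.value());
  }
  return {sparse, dense};
}

// Possible usage inside main():
//   auto params = splitParameters(*nn);
//   torch::optim::SparseAdam sparseOptimizer(params.first, torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
//   torch::optim::Adam denseOptimizer(params.second, torch::optim::AdamOptions(2e-4).beta1(0.5));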