Commit d4ec0a24 authored by Franck Dary

Made SparseAdam optimizer

parent be63334b
@@ -8,105 +8,273 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

namespace torch
{
namespace optim
{

class SparseAdam : public Optimizer
{
  public:

  template <typename ParameterContainer>
  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
    : Optimizer(std::forward<ParameterContainer>(parameters)), options(options_)
  {
  }
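
  // step() implements the standard Adam update. Writing the gradient as g_t,
  // the moment estimates kept in exp_average (m) and exp_average_sq (v) are
  //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
  //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
  // and the parameter update, with bias corrections (1 - beta1^t) and
  // (1 - beta2^t), is
  //   p -= lr / (1 - beta1^t) * m_t / (sqrt(v_t / (1 - beta2^t)) + eps)
  // The sparse branch performs this same update, but only on the rows that
  // actually appear in the sparse gradient's indices.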
  void step() override
  {
    for (size_t i = 0; i < parameters_.size(); ++i)
    {
      Tensor p = parameters_.at(i);
      if (!p.grad().defined())
        continue;

      auto& exp_average = buffer_at(exp_average_buffers, i);
      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);

      buffer_at(step_buffers, i) += 1;
      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));

      if (p.grad().is_sparse())
      {
        NoGradGuard guard;
        p.grad() = p.grad().coalesce();
        auto indices = p.grad().indices().squeeze();
        auto values = p.grad().values();

        // Update the first-moment estimate only on the touched rows.
        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average[indices[j].item<long>()] += exp_average_update_values[j];

        // Same for the second-moment estimate.
        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];

        auto numer = exp_average_update_values.add_(old_exp_average_values);
        exp_average_sq_update_values.add_(old_exp_average_sq_values);
        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;

        auto divided = numer.div(denom);
        for (unsigned int j = 0; j < indices.size(0); j++)
          p.data()[indices[j].item<long>()] += -step_size*divided[j];
      }
      else
      {
        if (options.weight_decay() > 0)
        {
          NoGradGuard guard;
          p.grad() = p.grad() + options.weight_decay() * p;
        }

        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());

        Tensor denom;
        if (options.amsgrad())
        {
          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
          denom = max_exp_average_sq / bias_correction2;
        }
        else
        {
          denom = exp_average_sq / bias_correction2;
        }

        const auto step_size = options.learning_rate() / bias_correction1;

        NoGradGuard guard;
        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
      }
    }
  }
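
  // Note that weight decay and the AMSGrad variant are only applied on the
  // dense path above; the sparse branch ignores options.weight_decay() and
  // options.amsgrad().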
  void save(serialize::OutputArchive& archive) const override
  {
    //serialize(*this, archive);
  }

  void load(serialize::InputArchive& archive) override
  {
    //serialize(*this, archive);
  }

  public:

  AdamOptions options;

  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

  private:

  SparseAdam() : options(0) {}

  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive)
  {
    _TORCH_OPTIM_SERIALIZE(step_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};

} // namespace optim
} // namespace torch
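
// A minimal usage sketch (hypothetical, not part of this file): SparseAdam is
// meant as a drop-in replacement for torch::optim::Adam when some parameters
// receive sparse gradients, e.g. an embedding table built with .sparse(true):
//
//   auto emb = torch::nn::Embedding(
//       torch::nn::EmbeddingOptions(1000, 16).sparse(true));
//   torch::optim::SparseAdam opt(emb->parameters(),
//                                torch::optim::AdamOptions(1e-3));
//   auto loss = emb(torch::tensor({1, 2, 3}, torch::kLong)).sum();
//   opt.zero_grad();
//   loss.backward();  // produces a sparse gradient for the embedding weight
//   opt.step();       // updates only rows 1, 2 and 3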
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;

//3m15s
struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
  }

  torch::Tensor forward(const torch::Tensor & input)
  {
    // The input is a batch of sentences (lists of word indices); the sentence
    // embedding is taken as the mean of the embeddings of its words.
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);
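
// Note: forward() returns softmax probabilities and main() feeds
// torch::log(prediction) into torch::nll_loss, which together compute a
// cross-entropy loss; applying torch::log_softmax to the logits instead
// would be the numerically more stable formulation.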
int main(int argc, char * argv[])
{
auto nn = Network();
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
for (auto & batch : batches)
{
optimizer.zero_grad();
auto prediction = nn(batch.first);
auto loss = torch::nll_loss(torch::log(prediction), batch.second);
loss.backward();
optimizer.step();
}
return 0;
}
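
// Hypothetical variant (not in this commit): to exercise the SparseAdam
// optimizer defined above, build the embedding with .sparse(true) and swap
// the optimizer, e.g.:
//
//   wordEmbeddings = register_module("word_embeddings",
//       torch::nn::Embedding(torch::nn::EmbeddingOptions(
//           maxNbEmbeddings, embeddingSize).sparse(true)));
//   torch::optim::SparseAdam optimizer(nn->parameters(),
//       torch::optim::AdamOptions(2e-4).beta1(0.5));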
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//