Commit d4ec0a24 authored by Franck Dary

Made SparseAdam optimizer
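Adds a hand-written torch::optim::SparseAdam: an Optimizer subclass that performs lazy, per-row Adam updates when a parameter's gradient is sparse (as produced by sparse embedding layers) and falls back to the standard dense Adam step otherwise. The previous dataset-generation main() is kept at the bottom of the file as a comment, replaced by a small synthetic benchmark.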

parent be63334b
@@ -8,105 +8,273 @@
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
namespace torch
{
namespace optim
{
class SparseAdam : public Optimizer
{
public:
template <typename ParameterContainer>
explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
: Optimizer(std::forward<ParameterContainer>(parameters)),
options(options_)
{
}
void step() override
{
for (size_t i = 0; i < parameters_.size(); ++i)
{
Tensor p = parameters_.at(i);
if (!p.grad().defined())
continue;
auto& exp_average = buffer_at(exp_average_buffers, i);
auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
buffer_at(step_buffers, i) += 1;
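// The moving averages start at zero, so for small t they are biased toward
// zero; dividing the step by (1 - beta^t) is the standard Adam bias correction.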
const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
if (p.grad().is_sparse())
{
NoGradGuard guard;
p.grad() = p.grad().coalesce();
auto indices = p.grad().indices().squeeze();
auto values = p.grad().values();
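// Sparse branch: only the rows that actually appear in the gradient are
// touched, mirroring the lazy behavior of Python's torch.optim.SparseAdam.
// sparse_mask() reads the current moment values at exactly those rows.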
auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
for (unsigned int j = 0; j < indices.size(0); j++)
exp_average[indices[j].item<long>()] += exp_average_update_values[j];
auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
for (unsigned int j = 0; j < indices.size(0); j++)
exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
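// Recombine old value + delta to obtain the updated m_t and v_t for the
// touched rows, then apply the Adam step row by row:
//   p[row] -= lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m_t / (sqrt(v_t) + eps)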
auto numer = exp_average_update_values.add_(old_exp_average_values);
exp_average_sq_update_values.add_(old_exp_average_sq_values);
auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
auto divided = numer.div(denom);
for (unsigned int j = 0; j < indices.size(0); j++)
p.data()[indices[j].item<long>()] += -step_size*divided[j];
}
else
{
if (options.weight_decay() > 0)
{
NoGradGuard guard;
p.grad() = p.grad() + options.weight_decay() * p;
}
exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
Tensor denom;
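// AMSGrad (Reddi et al., 2018) uses the running max of v_t instead of v_t
// itself, so the effective per-coordinate step size never increases.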
if (options.amsgrad())
{
auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
denom = max_exp_average_sq / bias_correction2;
}
else
{
denom = exp_average_sq / bias_correction2;
}
fmt::print("Generating dataset...\n");
const auto step_size = options.learning_rate() / bias_correction1;
NoGradGuard guard;
p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
}
}
}
void save(serialize::OutputArchive& archive) const override
{
//serialize(*this, archive);
}
void load(serialize::InputArchive& archive) override
{
//serialize(*this, archive);
}
public :
AdamOptions options;
std::vector<int64_t> step_buffers;
std::vector<Tensor> exp_average_buffers;
std::vector<Tensor> exp_average_sq_buffers;
std::vector<Tensor> max_exp_average_sq_buffers;
private:
SparseAdam() : options(0) {}
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive)
{
_TORCH_OPTIM_SERIALIZE(step_buffers);
_TORCH_OPTIM_SERIALIZE(exp_average_buffers);
_TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
_TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
}
};
} // optim
} // torch
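// Usage sketch (illustrative, not part of this commit; the table size, dim
// and lr below are made-up values): the optimizer is driven like any stock
// libtorch optimizer, but the Embedding must be built with .sparse(true)
// so that backward() emits sparse gradients.
//
//   auto emb = torch::nn::Embedding(
//       torch::nn::EmbeddingOptions(1000, 20).sparse(true));
//   torch::optim::SparseAdam optim(emb->parameters(),
//       torch::optim::AdamOptions(/*lr=*/1e-3));
//   auto loss = emb(torch::tensor({0, 3, 7}, torch::kLong)).sum();
//   loss.backward();   // gradient of a sparse embedding is a sparse tensor
//   optim.step();      // lazy per-row update defined above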
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;
//3m15s
struct NetworkImpl : torch::nn::Module
{
torch::nn::Linear linear{nullptr};
torch::nn::Embedding wordEmbeddings{nullptr};
NetworkImpl()
{
linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
};
torch::Tensor forward(const torch::Tensor & input)
{
// Each input is a batch of sentences given as word indices; the sentence embedding is taken as the mean of its word embeddings
auto embeddingsOfInput = wordEmbeddings(input).mean(1);
return torch::softmax(linear(embeddingsOfInput),1);
}
};
TORCH_MODULE(Network);
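// TORCH_MODULE generates 'Network', a shared_ptr-style holder around
// NetworkImpl; this is why main() below accesses it as nn->parameters().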
int main(int argc, char * argv[])
{
auto nn = Network();
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
for (auto & batch : batches)
{
optimizer.zero_grad();
auto prediction = nn(batch.first);
auto loss = torch::nll_loss(torch::log(prediction), batch.second);
loss.backward();
optimizer.step();
}
return 0;
}
//int main(int argc, char * argv[])
//{
// if (argc != 5)
// {
// fmt::print(stderr, "needs 4 arguments.\n");
// exit(1);
// }
//
// at::init_num_threads();
//
// std::string machineFile = argv[1];
// std::string mcdFile = argv[2];
// std::string tsvFile = argv[3];
// //std::string rawFile = argv[4];
// std::string rawFile = "";
//
// ReadingMachine machine(machineFile);
//
// BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
// SubConfig config(goldConfig);
//
// config.setState(machine.getStrategy().getInitialState());
//
// std::vector<torch::Tensor> contexts;
// std::vector<torch::Tensor> classes;
//
// fmt::print("Generating dataset...\n");
//
// Dict dict(Dict::State::Open);
//
// while (true)
// {
// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
// if (!transition)
// util::myThrow("No transition appliable !");
//
// auto context = config.extractContext(5,5,dict);
// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
// int goldIndex = 3;
// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
// classes.emplace_back(gold);
//
// transition->apply(config);
// config.addToHistory(transition->getName());
//
// auto movement = machine.getStrategy().getMovement(config, transition->getName());
// if (movement == Strategy::endMovement)
// break;
//
// config.setState(movement.first);
// if (!config.moveWordIndex(movement.second))
// util::myThrow("Cannot move word index !");
//
// if (config.needsUpdate())
// config.update();
// }
//
// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
// int nbExamples = *dataset.size();
// fmt::print("Done! size={}\n", nbExamples);
//
// int batchSize = 100;
// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
// TestNetwork nn(machine.getTransitionSet().size(), 5);
// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
// for (int epoch = 1; epoch <= 1; ++epoch)
// {
// float totalLoss = 0.0;
// torch::Tensor example;
// int currentBatchNumber = 0;
//
// for (auto & batch : *dataLoader)
// {
// optimizer.zero_grad();
//
// auto data = batch.data;
// auto labels = batch.target.squeeze();
//
// auto prediction = nn(data);
// example = prediction[0];
//
// auto loss = torch::nll_loss(torch::log(prediction), labels);
// totalLoss += loss.item<float>();
// loss.backward();
// optimizer.step();
//
// if (++currentBatchNumber*batchSize % 1000 == 0)
// {
// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
// std::fflush(stdout);
// }
// }
//
// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
// }
//
// return 0;
//}
//