diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 2aac98cf5171f86e9bbd93d953390d246e21a0e7..3f995d40660407305e0c92bcf1a94e72d5325934 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,105 +8,273 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-int main(int argc, char * argv[])
+namespace torch
+{
+namespace optim
 {
-  if (argc != 5)
+
+class SparseAdam : public Optimizer
+{
+  public:
+
+  template <typename ParameterContainer>
+  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
+    : Optimizer(std::forward<ParameterContainer>(parameters)),
+      options(options_)
   {
-    fmt::print(stderr, "needs 4 arguments.\n");
-    exit(1);
   }
 
-  std::string machineFile = argv[1];
-  std::string mcdFile = argv[2];
-  std::string tsvFile = argv[3];
-  //std::string rawFile = argv[4];
-  std::string rawFile = "";
+  void step() override
+  {
+    for (size_t i = 0; i < parameters_.size(); ++i)
+    {
+      Tensor p = parameters_.at(i);
+      if (!p.grad().defined())
+        continue;
+
+      auto& exp_average = buffer_at(exp_average_buffers, i);
+      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);
+
+      buffer_at(step_buffers, i) += 1;
+      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
+      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));
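+      // Sparse gradients (e.g. from an Embedding created with .sparse(true))
+      // only carry rows for the indices seen in the batch, so the moving
+      // averages and the weights are updated row by row in that branch.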
+      if (p.grad().is_sparse())
+      {
+        NoGradGuard guard;
+        p.grad() = p.grad().coalesce();
+        auto indices = p.grad().indices().squeeze();
+        auto values = p.grad().values();
 
-  ReadingMachine machine(machineFile);
+        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
+        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          exp_average[indices[j].item<long>()] += exp_average_update_values[j];
+        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
+        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];
 
-  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
-  SubConfig config(goldConfig);
+        auto numer = exp_average_update_values.add_(old_exp_average_values);
+        exp_average_sq_update_values.add_(old_exp_average_sq_values);
+        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
+        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;
+        auto divided = numer.div(denom);
+        for (unsigned int j = 0; j < indices.size(0); j++)
+          p.data()[indices[j].item<long>()] += -step_size * divided[j];
+      }
+      else
+      {
+        if (options.weight_decay() > 0)
+        {
+          NoGradGuard guard;
+          p.grad() = p.grad() + options.weight_decay() * p;
+        }
 
-  config.setState(machine.getStrategy().getInitialState());
+        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
+        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());
 
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
+        Tensor denom;
+        if (options.amsgrad())
+        {
+          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
+          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
+          denom = max_exp_average_sq / bias_correction2;
+        }
+        else
+        {
+          denom = exp_average_sq / bias_correction2;
+        }
 
-  fmt::print("Generating dataset...\n");
+        const auto step_size = options.learning_rate() / bias_correction1;
 
-  Dict dict(Dict::State::Open);
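+        // Final in-place step of the dense branch, kept out of the autograd
+        // graph: p -= step_size * exp_average / (denom.sqrt() + eps).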
+        NoGradGuard guard;
+        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
+      }
+    }
+  }
 
-  while (true)
+  void save(serialize::OutputArchive& archive) const override
   {
-    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
-    if (!transition)
-      util::myThrow("No transition appliable !");
+    //serialize(*this, archive);
   }
 
-    auto context = config.extractContext(5,5,dict);
-    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+  void load(serialize::InputArchive& archive) override
+  {
+    //serialize(*this, archive);
+  }
 
-    int goldIndex = 3;
-    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
+  public:
 
-    classes.emplace_back(gold);
+  AdamOptions options;
 
-    transition->apply(config);
-    config.addToHistory(transition->getName());
+  std::vector<int64_t> step_buffers;
+  std::vector<Tensor> exp_average_buffers;
+  std::vector<Tensor> exp_average_sq_buffers;
+  std::vector<Tensor> max_exp_average_sq_buffers;
 
-    auto movement = machine.getStrategy().getMovement(config, transition->getName());
-    if (movement == Strategy::endMovement)
-      break;
+  private:
 
-    config.setState(movement.first);
-    if (!config.moveWordIndex(movement.second))
-      util::myThrow("Cannot move word index !");
+  SparseAdam() : options(0) {}
 
-    if (config.needsUpdate())
-      config.update();
+  template <typename Self, typename Archive>
+  static void serialize(Self& self, Archive& archive)
+  {
+    _TORCH_OPTIM_SERIALIZE(step_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
+    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
   }
+};
 
-  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+} // namespace optim
+} // namespace torch
 
-  int nbExamples = *dataset.size();
-  fmt::print("Done! size={}\n", nbExamples);
+constexpr int batchSize = 50;
+constexpr int nbExamples = 350000;
+constexpr int embeddingSize = 20;
+constexpr int nbClasses = 15;
+constexpr int nbWordsPerDatapoint = 5;
+constexpr int maxNbEmbeddings = 1000000;
 
-  int batchSize = 100;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
-
-  TestNetwork nn(machine.getTransitionSet().size(), 5);
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-
-  for (int epoch = 1; epoch <= 1; ++epoch)
+// runtime: 3m15s
+struct NetworkImpl : torch::nn::Module
+{
+  torch::nn::Linear linear{nullptr};
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  NetworkImpl()
   {
-    float totalLoss = 0.0;
-    torch::Tensor example;
-    int currentBatchNumber = 0;
-
-    for (auto & batch : *dataLoader)
-    {
-      optimizer.zero_grad();
-
-      auto data = batch.data;
-      auto labels = batch.target.squeeze();
-
-      auto prediction = nn(data);
-      example = prediction[0];
-
-      auto loss = torch::nll_loss(torch::log(prediction), labels);
-      totalLoss += loss.item<float>();
-      loss.backward();
-      optimizer.step();
+    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
+    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
+  }
+  torch::Tensor forward(const torch::Tensor & input)
+  {
+    // The input is a batch of sentences given as word indices; a sentence
+    // embedding is taken to be the mean of the embeddings of its words.
+    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
+    return torch::softmax(linear(embeddingsOfInput), 1);
+  }
+};
+TORCH_MODULE(Network);
 
-      if (++currentBatchNumber*batchSize % 1000 == 0)
-      {
-        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
-        std::fflush(stdout);
-      }
-    }
+int main(int argc, char * argv[])
+{
+  auto nn = Network();
+  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));
+  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
+  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
+    batches.emplace_back(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, {batchSize}, at::kLong));
 
-  fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
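+  // nll_loss over log(softmax(...)) is cross-entropy; torch::log_softmax
+  // would compute the same quantity with better numerical stability.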
appliable !"); +// +// auto context = config.extractContext(5,5,dict); +// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); +// +// int goldIndex = 3; +// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone(); +// +// classes.emplace_back(gold); +// +// transition->apply(config); +// config.addToHistory(transition->getName()); +// +// auto movement = machine.getStrategy().getMovement(config, transition->getName()); +// if (movement == Strategy::endMovement) +// break; +// +// config.setState(movement.first); +// if (!config.moveWordIndex(movement.second)) +// util::myThrow("Cannot move word index !"); +// +// if (config.needsUpdate()) +// config.update(); +// } +// +// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>()); +// +// int nbExamples = *dataset.size(); +// fmt::print("Done! size={}\n", nbExamples); +// +// int batchSize = 100; +// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); +// +// TestNetwork nn(machine.getTransitionSet().size(), 5); +// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); +// +// for (int epoch = 1; epoch <= 1; ++epoch) +// { +// float totalLoss = 0.0; +// torch::Tensor example; +// int currentBatchNumber = 0; +// +// for (auto & batch : *dataLoader) +// { +// optimizer.zero_grad(); +// +// auto data = batch.data; +// auto labels = batch.target.squeeze(); +// +// auto prediction = nn(data); +// example = prediction[0]; +// +// auto loss = torch::nll_loss(torch::log(prediction), labels); +// totalLoss += loss.item<float>(); +// loss.backward(); +// optimizer.step(); +// +// if (++currentBatchNumber*batchSize % 1000 == 0) +// { +// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples); +// std::fflush(stdout); +// } +// } +// +// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss); +// } +// +// return 0; +//} +//