#include <cstdio>
#include <torch/torch.h>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

namespace torch
{
namespace optim
{

// Adam variant that also handles sparse gradients (e.g. coming from a sparse
// torch::nn::Embedding). It is written against the old libtorch optimizer
// interface (parameters_, buffer_at, AdamOptions::beta1(), ...).
class SparseAdam : public Optimizer
{
  public:

  template <typename ParameterContainer>
  explicit SparseAdam(ParameterContainer&& parameters, const AdamOptions& options_)
    : Optimizer(std::forward<ParameterContainer>(parameters)), options(options_)
  {
  }

  void step() override
  {
    for (size_t i = 0; i < parameters_.size(); ++i)
    {
      Tensor p = parameters_.at(i);
      if (!p.grad().defined())
        continue;

      auto& exp_average = buffer_at(exp_average_buffers, i);
      auto& exp_average_sq = buffer_at(exp_average_sq_buffers, i);

      buffer_at(step_buffers, i) += 1;
      const auto bias_correction1 = 1 - std::pow(options.beta1(), buffer_at(step_buffers, i));
      const auto bias_correction2 = 1 - std::pow(options.beta2(), buffer_at(step_buffers, i));

      if (p.grad().is_sparse())
      {
        // Sparse gradient : only the rows that actually received a gradient
        // are read and updated.
        NoGradGuard guard;
        p.grad() = p.grad().coalesce();
        auto indices = p.grad().indices().squeeze();
        auto values = p.grad().values();

        auto old_exp_average_values = exp_average.sparse_mask(p.grad())._values();
        auto exp_average_update_values = values.sub(old_exp_average_values).mul_(1 - options.beta1());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average[indices[j].item<long>()] += exp_average_update_values[j];

        auto old_exp_average_sq_values = exp_average_sq.sparse_mask(p.grad())._values();
        auto exp_average_sq_update_values = values.pow(2).sub_(old_exp_average_sq_values).mul_(1 - options.beta2());
        for (unsigned int j = 0; j < indices.size(0); j++)
          exp_average_sq[indices[j].item<long>()] += exp_average_sq_update_values[j];

        auto numer = exp_average_update_values.add_(old_exp_average_values);
        exp_average_sq_update_values.add_(old_exp_average_sq_values);
        auto denom = exp_average_sq_update_values.sqrt_().add_(options.eps());
        const auto step_size = options.learning_rate() * std::sqrt(bias_correction2) / bias_correction1;

        auto divided = numer.div(denom);
        for (unsigned int j = 0; j < indices.size(0); j++)
          p.data()[indices[j].item<long>()] += -step_size * divided[j];
      }
      else
      {
        // Dense gradient : standard Adam update (optionally AMSGrad).
        if (options.weight_decay() > 0)
        {
          NoGradGuard guard;
          p.grad() = p.grad() + options.weight_decay() * p;
        }

        exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1());
        exp_average_sq.mul_(options.beta2()).addcmul_(p.grad(), p.grad(), 1 - options.beta2());

        Tensor denom;
        if (options.amsgrad())
        {
          auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i);
          max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq);
          denom = max_exp_average_sq / bias_correction2;
        }
        else
        {
          denom = exp_average_sq / bias_correction2;
        }

        const auto step_size = options.learning_rate() / bias_correction1;

        NoGradGuard guard;
        p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size);
      }
    }
  }

  void save(serialize::OutputArchive& archive) const override
  {
    //serialize(*this, archive);
  }

  void load(serialize::InputArchive& archive) override
  {
    //serialize(*this, archive);
  }

  public:

  AdamOptions options;

  std::vector<int64_t> step_buffers;
  std::vector<Tensor> exp_average_buffers;
  std::vector<Tensor> exp_average_sq_buffers;
  std::vector<Tensor> max_exp_average_sq_buffers;

  private:

  SparseAdam() : options(0) {}

  template <typename Self, typename Archive>
  static void serialize(Self& self, Archive& archive)
  {
    _TORCH_OPTIM_SERIALIZE(step_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
  }
};

} // namespace optim
} // namespace torch
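// Minimal usage sketch of the SparseAdam defined above (illustrative only :
// this helper is never called, and its sizes/hyperparameters are made up).
// With a sparse embedding table, p.grad() is a sparse tensor, so step() takes
// the sparse branch and only touches the rows that were actually looked up.
void sparseAdamUsageSketch()
{
  auto embeddings = torch::nn::Embedding(torch::nn::EmbeddingOptions(1000, 8).sparse(true));
  torch::optim::SparseAdam optimizer(embeddings->parameters(), torch::optim::AdamOptions(1e-3));

  auto indices = torch::randint(1000, {4, 3}, at::kLong);
  auto loss = embeddings(indices).sum();

  optimizer.zero_grad();
  loss.backward();  // the embedding weight receives a sparse gradient
  optimizer.step(); // only the looked-up rows of the table are updated
}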
constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;
//3m15s

struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(false)));
  }

  torch::Tensor forward(const torch::Tensor & input)
  {
    // The input is a batch of sentences, each given as a list of word indices;
    // the sentence embedding is the mean of its word embeddings.
    // Shapes : [batchSize, nbWordsPerDatapoint] -> [batchSize, embeddingSize] -> [batchSize, nbClasses]
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);

int main(int argc, char * argv[])
{
  auto nn = Network();
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5).weight_decay(0.1));

  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, {batchSize}, at::kLong)));

  for (auto & batch : batches)
  {
    optimizer.zero_grad();
    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
    optimizer.step();
  }

  return 0;
}
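// Sparse counterpart of the loop in main(), as a sketch (hypothetical helper,
// never called) : the embedding table is switched to sparse gradients and
// torch::optim::Adam is replaced by the SparseAdam defined above. Weight decay
// is left out because SparseAdam::step() only applies it to dense gradients.
void trainWithSparseGradients(const std::vector<std::pair<torch::Tensor,torch::Tensor>> & batches)
{
  auto wordEmbeddings = torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true));
  auto linear = torch::nn::Linear(embeddingSize, nbClasses);

  // SparseAdam handles the sparse embedding gradient and the dense linear
  // gradients in the same parameter list.
  std::vector<torch::Tensor> parameters = wordEmbeddings->parameters();
  for (auto & parameter : linear->parameters())
    parameters.emplace_back(parameter);

  torch::optim::SparseAdam optimizer(parameters, torch::optim::AdamOptions(2e-4).beta1(0.5));

  for (auto & batch : batches)
  {
    optimizer.zero_grad();
    auto prediction = torch::softmax(linear(wordEmbeddings(batch.first).mean(1)), 1);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();
    optimizer.step();
  }
}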
size={}\n", nbExamples); // // int batchSize = 100; // auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); // // TestNetwork nn(machine.getTransitionSet().size(), 5); // torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); // // for (int epoch = 1; epoch <= 1; ++epoch) // { // float totalLoss = 0.0; // torch::Tensor example; // int currentBatchNumber = 0; // // for (auto & batch : *dataLoader) // { // optimizer.zero_grad(); // // auto data = batch.data; // auto labels = batch.target.squeeze(); // // auto prediction = nn(data); // example = prediction[0]; // // auto loss = torch::nll_loss(torch::log(prediction), labels); // totalLoss += loss.item<float>(); // loss.backward(); // optimizer.step(); // // if (++currentBatchNumber*batchSize % 1000 == 0) // { // fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples); // std::fflush(stdout); // } // } // // fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss); // } // // return 0; //} //