// Generate a training dataset by running a ReadingMachine over a gold
// configuration, then train a TestNetwork to predict transitions from the
// extracted contexts.

#include <cstdio>
#include <cstdlib>
#include <torch/torch.h>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

// Commented-out experiment : a toy network mixing a sparse optimizer (for the
// embeddings) and a dense optimizer (for the linear layer).
//constexpr int batchSize = 50;
//constexpr int nbExamples = 350000;
//constexpr int embeddingSize = 20;
//constexpr int nbClasses = 15;
//constexpr int nbWordsPerDatapoint = 5;
//constexpr int maxNbEmbeddings = 1000000;
//
//struct NetworkImpl : torch::nn::Module
//{
//  torch::nn::Linear linear{nullptr};
//  torch::nn::Embedding wordEmbeddings{nullptr};
//
//  std::vector<torch::Tensor> _sparseParameters;
//  std::vector<torch::Tensor> _denseParameters;
//
//  NetworkImpl()
//  {
//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
//    auto params = linear->parameters();
//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
//
//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
//    params = wordEmbeddings->parameters();
//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
//  };
//
//  const std::vector<torch::Tensor> & denseParameters()
//  {
//    return _denseParameters;
//  }
//
//  const std::vector<torch::Tensor> & sparseParameters()
//  {
//    return _sparseParameters;
//  }
//
//  torch::Tensor forward(const torch::Tensor & input)
//  {
//    // Each datapoint is a sentence (list of word indices) ; its embedding is the mean of its word embeddings.
//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
//    return torch::softmax(linear(embeddingsOfInput), 1);
//  }
//};
//TORCH_MODULE(Network);
//
//int main(int argc, char * argv[])
//{
//  auto nn = Network();
//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
//  std::vector<std::pair<torch::Tensor, torch::Tensor>> batches;
//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
//
//  for (auto & batch : batches)
//  {
//    sparseOptimizer.zero_grad();
//    denseOptimizer.zero_grad();
//    auto prediction = nn(batch.first);
//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
//    loss.backward();
//    sparseOptimizer.step();
//    denseOptimizer.step();
//  }
//
//  return 0;
//}

int main(int argc, char * argv[])
{
  if (argc != 5)
  {
    fmt::print(stderr, "Usage : {} machineFile mcdFile tsvFile rawFile\n", argv[0]);
    exit(1);
  }

  at::init_num_threads();

  std::string machineFile = argv[1];
  std::string mcdFile = argv[2];
  std::string tsvFile = argv[3];
  //std::string rawFile = argv[4];
  std::string rawFile = "";

  ReadingMachine machine(machineFile);

  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
  SubConfig config(goldConfig);

  config.setState(machine.getStrategy().getInitialState());

  std::vector<torch::Tensor> contexts;
  std::vector<torch::Tensor> classes;

  fmt::print("Generating dataset...\n");

  Dict dict(Dict::State::Open);

  // Run the machine on the gold configuration, extracting one (context, class) pair per transition.
  while (true)
  {
    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
    if (!transition)
      util::myThrow("No transition applicable !");

    // Extract the feature context of the current configuration, encoded through the dict.
    auto context = config.extractContext(5, 5, dict);
    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
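    // Gold class for this example. Note : the index is currently a fixed constant
    // rather than being derived from the chosen transition.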
    int goldIndex = 3;
    auto gold = torch::zeros(1, at::kLong);
    gold[0] = goldIndex;
    classes.emplace_back(gold);

    // Apply the transition, then follow the strategy to the next state and word index.
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");

    if (config.needsUpdate())
      config.update();
  }

  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());

  int nbExamples = *dataset.size();
  fmt::print("Done! size={}\n", nbExamples);

  int batchSize = 100;
  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));

  // One optimizer for the dense parameters, one for the sparse embedding parameters.
  TestNetwork nn(machine.getTransitionSet().size(), 5);
  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5));
  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5));

  for (int epoch = 1; epoch <= 2; ++epoch)
  {
    float totalLoss = 0.0;
    float lossSoFar = 0.0;
    torch::Tensor example;
    int currentBatchNumber = 0;

    for (auto & batch : *dataLoader)
    {
      denseOptimizer.zero_grad();
      sparseOptimizer.zero_grad();

      auto data = batch.data;
      auto labels = batch.target.squeeze();

      auto prediction = nn(data);
      example = prediction[0];

      auto loss = torch::nll_loss(torch::log(prediction), labels);
      totalLoss += loss.item<float>();
      lossSoFar += loss.item<float>();

      loss.backward();

      denseOptimizer.step();
      sparseOptimizer.step();

      // Report progress every 1000 examples.
      if ((++currentBatchNumber * batchSize) % 1000 == 0)
      {
        fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0 * currentBatchNumber * batchSize / nbExamples, lossSoFar);
        std::fflush(stdout);
        lossSoFar = 0;
      }
    }

    fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss);
  }

  return 0;
}