#include <cstdio>
#include <torch/torch.h>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"

constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 1000000;

// Timing note: ~3m15s for the full run.
struct NetworkImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};

  NetworkImpl()
  {
    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
    // sparse(true) makes the embedding table produce sparse gradients, which only a
    // sparse-aware optimizer can consume.
    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  }

  torch::Tensor forward(const torch::Tensor & input)
  {
    // The input is a batch of sentences (one row of word indices per sentence);
    // a sentence embedding is the mean of the embeddings of its words.
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};
TORCH_MODULE(Network);

int main(int argc, char * argv[])
{
  auto nn = Network();

  // Each optimizer must only own the parameters whose gradient layout it supports:
  // the sparse embedding weights go to SparseAdam, the dense linear parameters to Adam.
  // Handing nn->parameters() (every parameter) to both would fail, since Adam cannot
  // consume sparse gradients and SparseAdam cannot consume dense ones.
  torch::optim::SparseAdam sparseOptimizer(std::vector<torch::Tensor>{nn->wordEmbeddings->weight}, torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
  torch::optim::Adam denseOptimizer(std::vector<torch::Tensor>{nn->linear->weight, nn->linear->bias}, torch::optim::AdamOptions(2e-4).beta1(0.5));

  // Random data standing in for real examples: word indices in, class labels out.
  std::vector<std::pair<torch::Tensor, torch::Tensor>> batches;
  batches.reserve(nbExamples / batchSize);
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
    batches.emplace_back(torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, {batchSize}, at::kLong));

  for (auto & batch : batches)
  {
    sparseOptimizer.zero_grad();
    denseOptimizer.zero_grad();

    auto prediction = nn(batch.first);
    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
    loss.backward();

    sparseOptimizer.step();
    denseOptimizer.step();
  }

  return 0;
}
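// A minimal sanity-check sketch (not part of the original program; the helper name and
// its reuse of the globals above are my own): it runs one backward pass and prints the
// gradient layout of each parameter, which is one way to verify that the embedding
// weights really require the sparse optimizer while the linear layer requires the dense one.
void checkGradientLayouts(Network & nn)
{
  auto input = torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong);
  auto labels = torch::randint(nbClasses, {batchSize}, at::kLong);

  auto loss = torch::nll_loss(torch::log(nn(input)), labels);
  loss.backward();

  // Expected: true for the sparse(true) embedding table, false for the linear layer.
  fmt::print("wordEmbeddings grad is sparse: {}\n", nn->wordEmbeddings->weight.grad().is_sparse());
  fmt::print("linear grad is sparse: {}\n", nn->linear->weight.grad().is_sparse());
}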
//int main(int argc, char * argv[])
//{
//  if (argc != 5)
//  {
//    fmt::print(stderr, "needs 4 arguments.\n");
//    exit(1);
//  }
//
//  at::init_num_threads();
//
//  std::string machineFile = argv[1];
//  std::string mcdFile = argv[2];
//  std::string tsvFile = argv[3];
//  //std::string rawFile = argv[4];
//  std::string rawFile = "";
//
//  ReadingMachine machine(machineFile);
//
//  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
//  SubConfig config(goldConfig);
//
//  config.setState(machine.getStrategy().getInitialState());
//
//  std::vector<torch::Tensor> contexts;
//  std::vector<torch::Tensor> classes;
//
//  fmt::print("Generating dataset...\n");
//
//  Dict dict(Dict::State::Open);
//
//  while (true)
//  {
//    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
//    if (!transition)
//      util::myThrow("No transition applicable!");
//
//    auto context = config.extractContext(5, 5, dict);
//    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
//
//    int goldIndex = 3;
//    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
//
//    classes.emplace_back(gold);
//
//    transition->apply(config);
//    config.addToHistory(transition->getName());
//
//    auto movement = machine.getStrategy().getMovement(config, transition->getName());
//    if (movement == Strategy::endMovement)
//      break;
//
//    config.setState(movement.first);
//    if (!config.moveWordIndex(movement.second))
//      util::myThrow("Cannot move word index!");
//
//    if (config.needsUpdate())
//      config.update();
//  }
//
//  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
//
//  int nbExamples = *dataset.size();
//  fmt::print("Done! size={}\n", nbExamples);
//
//  int batchSize = 100;
//  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
//
//  TestNetwork nn(machine.getTransitionSet().size(), 5);
//  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
//
//  for (int epoch = 1; epoch <= 1; ++epoch)
//  {
//    float totalLoss = 0.0;
//    torch::Tensor example;
//    int currentBatchNumber = 0;
//
//    for (auto & batch : *dataLoader)
//    {
//      optimizer.zero_grad();
//
//      auto data = batch.data;
//      auto labels = batch.target.squeeze();
//
//      auto prediction = nn(data);
//      example = prediction[0];
//
//      auto loss = torch::nll_loss(torch::log(prediction), labels);
//      totalLoss += loss.item<float>();
//      loss.backward();
//      optimizer.step();
//
//      if (++currentBatchNumber*batchSize % 1000 == 0)
//      {
//        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
//        std::fflush(stdout);
//      }
//    }
//
//    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
//  }
//
//  return 0;
//}
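// For reference, a minimal sketch of the kind of dataset the commented-out main above
// expects. The project's real ConfigDataset lives in ConfigDataset.hpp; this stand-in
// (name and members are assumptions) only illustrates the torch::data::Dataset
// interface it has to satisfy for make_data_loader and Stack<> to work:
//
//struct MinimalConfigDataset : torch::data::Dataset<MinimalConfigDataset>
//{
//  torch::Tensor contexts, classes;
//
//  MinimalConfigDataset(const std::vector<torch::Tensor> & ctx, const std::vector<torch::Tensor> & cls)
//    : contexts(torch::stack(ctx)), classes(torch::stack(cls)) {}
//
//  // One example = (context features, gold class), matching batch.data / batch.target.
//  torch::data::Example<> get(size_t index) override
//  {
//    return {contexts[index], classes[index]};
//  }
//
//  torch::optional<size_t> size() const override
//  {
//    return (size_t)contexts.size(0);
//  }
//};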