From be63334bb4116b31c4b0464dde155d2b5633e9bc Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Wed, 22 Jan 2020 22:26:50 +0100 Subject: [PATCH] tests --- dev/src/dev.cpp | 15 +++++++++++---- torch_modules/src/TestNetwork.cpp | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index b5bd81f..2aac98c 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -67,18 +67,20 @@ int main(int argc, char * argv[]) auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>()); - fmt::print("Done! size={}\n", *dataset.size()); + int nbExamples = *dataset.size(); + fmt::print("Done! size={}\n", nbExamples); int batchSize = 100; - auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize)); + auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); TestNetwork nn(machine.getTransitionSet().size(), 5); torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); - for (int epoch = 1; epoch <= 5; ++epoch) + for (int epoch = 1; epoch <= 1; ++epoch) { float totalLoss = 0.0; torch::Tensor example; + int currentBatchNumber = 0; for (auto & batch : *dataLoader) { @@ -94,10 +96,15 @@ int main(int argc, char * argv[]) totalLoss += loss.item<float>(); loss.backward(); optimizer.step(); + + if (++currentBatchNumber*batchSize % 1000 == 0) + { + fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples); + std::fflush(stdout); + } } fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss); - std::cout << example << std::endl; } return 0; diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp index 3e3c010..f379c73 100644 --- a/torch_modules/src/TestNetwork.cpp +++ b/torch_modules/src/TestNetwork.cpp @@ -2,7 +2,7 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex) { - constexpr int embeddingsSize = 100; + constexpr int embeddingsSize = 30; wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize)); linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs)); this->focusedIndex = focusedIndex; -- GitLab