From be63334bb4116b31c4b0464dde155d2b5633e9bc Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Wed, 22 Jan 2020 22:26:50 +0100
Subject: [PATCH] tests

---
 dev/src/dev.cpp                   | 15 +++++++++++----
 torch_modules/src/TestNetwork.cpp |  2 +-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index b5bd81f..2aac98c 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -67,18 +67,20 @@ int main(int argc, char * argv[])
 
   auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
 
-  fmt::print("Done! size={}\n", *dataset.size());
+  int nbExamples = *dataset.size();
+  fmt::print("Done! size={}\n", nbExamples);
 
   int batchSize = 100;
-  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize));
+  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
 
   TestNetwork nn(machine.getTransitionSet().size(), 5);
   torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
 
-  for (int epoch = 1; epoch <= 5; ++epoch)
+  for (int epoch = 1; epoch <= 1; ++epoch)
   {
     float totalLoss = 0.0;
     torch::Tensor example;
+    int currentBatchNumber = 0;
 
     for (auto & batch : *dataLoader)
     {
@@ -94,10 +96,15 @@ int main(int argc, char * argv[])
       totalLoss += loss.item<float>();
       loss.backward();
       optimizer.step();
+
+      if (++currentBatchNumber*batchSize % 1000 == 0)
+      {
+        fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples);
+        std::fflush(stdout);
+      }
     }
 
     fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
-    std::cout << example << std::endl;
   }
 
   return 0;
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
index 3e3c010..f379c73 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/TestNetwork.cpp
@@ -2,7 +2,7 @@
 
 TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
 {
-  constexpr int embeddingsSize = 100;
+  constexpr int embeddingsSize = 30;
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
   linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
   this->focusedIndex = focusedIndex;
-- 
GitLab