From d846104c2cc3cc4bf31f1a25029af0b567b90e19 Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Tue, 25 Feb 2020 01:44:36 +0100
Subject: [PATCH] Allow batch size 1

---
 trainer/src/Trainer.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 64466b5..afbf965 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -82,6 +82,8 @@ float Trainer::epoch(bool printAdvancement)
 
     auto prediction = machine.getClassifier()->getNN()(data);
 
+    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
+
     auto loss = lossFct(prediction, labels);
     try
     {
-- 
GitLab