From d846104c2cc3cc4bf31f1a25029af0b567b90e19 Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Tue, 25 Feb 2020 01:44:36 +0100 Subject: [PATCH] Allow batch size 1 --- trainer/src/Trainer.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 64466b5..afbf965 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -82,6 +82,8 @@ float Trainer::epoch(bool printAdvancement) auto prediction = machine.getClassifier()->getNN()(data); + labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); + auto loss = lossFct(prediction, labels); try { -- GitLab