diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 64466b5074ebd2bd9397208638dac6c69b5b2185..afbf9652f1b0152d73741b4fa341cf725a8e7fc1 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -82,6 +82,8 @@ float Trainer::epoch(bool printAdvancement)
 
     auto prediction = machine.getClassifier()->getNN()(data);
 
+    labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0));
+
     auto loss = lossFct(prediction, labels);
     try
     {