diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 64466b5074ebd2bd9397208638dac6c69b5b2185..afbf9652f1b0152d73741b4fa341cf725a8e7fc1 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -82,6 +82,8 @@ float Trainer::epoch(bool printAdvancement) auto prediction = machine.getClassifier()->getNN()(data); + labels = labels.reshape(labels.dim() == 0 ? 1 : labels.size(0)); + auto loss = lossFct(prediction, labels); try {