diff --git a/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py b/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py index 9ef17e0d71e8e3a90dcaed6491c6fbc239cddf79..4f97eb04c7ac7f506ba616eca756c6d2eed66945 100644 --- a/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py +++ b/multiview_platform/MonoMultiViewClassifiers/ResultAnalysis.py @@ -244,7 +244,17 @@ def plotMetricScores(trainScores, testScores, names, nbResults, metricName, file f.savefig(fileName+'.png') plt.close() import pandas as pd - dataframe = pd.DataFrame(np.transpose(np.concatenate((trainScores.reshape((trainScores.shape[0], 1)), testScores.reshape((trainScores.shape[0], 1))), axis=1)), columns=names) + if train_STDs is None : + dataframe = pd.DataFrame(np.transpose(np.concatenate(( + trainScores.reshape((trainScores.shape[0], 1)), + testScores.reshape((trainScores.shape[0], 1))), axis=1)), + columns=names) + else: + dataframe = pd.DataFrame(np.transpose(np.concatenate(( + trainScores.reshape((trainScores.shape[0], 1)), + train_STDs.reshape((trainScores.shape[0], 1)), + testScores.reshape((trainScores.shape[0], 1)), + test_STDs.reshape((trainScores.shape[0], 1))), axis=1)), columns=names) dataframe.to_csv(fileName+".csv")