diff --git a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoost.py b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoost.py index ea3c4f1eaffc39da995a16132522ec5e9b9734e4..68497d530f747829c68de54b8167119f47f39e7d 100644 --- a/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoost.py +++ b/multiview_platform/MonoMultiViewClassifiers/MonoviewClassifiers/CQBoost.py @@ -24,6 +24,16 @@ class CQBoost(ColumnGenerationClassifier, BaseMonoviewClassifier): def getInterpret(self, directory, y_test): np.savetxt(directory + "train_metrics.csv", self.train_metrics, delimiter=',') + np.savetxt(directory + "y_test_step.csv", self.step_decisions, + delimiter=',') + step_metrics = [] + for step_index in range(self.step_decisions.shape[1] - 1): + step_metrics.append(self.plotted_metric.score(y_test, + self.step_decisions[:, + step_index])) + step_metrics = np.array(step_metrics) + np.savetxt(directory + "step_test_metrics.csv", step_metrics, + delimiter=',') return getInterpretBase(self, directory, "CQBoost", self.weights_, y_test)