diff --git a/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py b/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
index 3ad80ca2ac5dab2de39af6b358286c90823857f3..0b7597d80eb72ad2f7a795418ffb0e65356137b8 100644
--- a/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
+++ b/multiview_platform/mono_multi_view_classifiers/monoview/exec_classif_mono_view.py
@@ -150,12 +150,13 @@ def exec_monoview(directory, X, Y, database_name, labels_names, classification_i
                                              database_name=database_name,
                                              nb_cores=nb_cores,
                                              duration=whole_duration)
-    string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
+    string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
+        confusion_matrix = result_analyzer.analyze()
     logging.debug("Done:\t Getting results")

     logging.debug("Start:\t Saving preds")
     save_results(string_analysis, output_file_name, full_pred, train_pred,
-                 y_train, images_analysis, y_test)
+                 y_train, images_analysis, y_test, confusion_matrix)
     logging.info("Done:\t Saving results")

     view_index = args["view_index"]
@@ -222,11 +223,13 @@ def get_hyper_params(classifier_module, search_method, classifier_module_name,


 def save_results(string_analysis, output_file_name, full_labels_pred, y_train_pred,
-                 y_train, images_analysis, y_test):
+                 y_train, images_analysis, y_test, confusion_matrix):
     logging.info(string_analysis)
     output_text_file = open(output_file_name + 'summary.txt', 'w')
     output_text_file.write(string_analysis)
     output_text_file.close()
+    np.savetxt(output_file_name+"confusion_matrix.csv", confusion_matrix,
+               delimiter=', ')
     np.savetxt(output_file_name + "full_pred.csv",
                full_labels_pred.astype(np.int16), delimiter=",")
     np.savetxt(output_file_name + "train_pred.csv",
diff --git a/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py b/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
index deebc7255491ab89688e7d33062b174fdd68d08d..f3a9a04b9c6c7bd5040c659c7ac10690d612cc8e 100644
--- a/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
+++ b/multiview_platform/mono_multi_view_classifiers/multiview/exec_multiview.py
@@ -71,7 +71,8 @@ def init_constants(kwargs, classification_indices, metrics,
         directory, base_file_name


-def save_results(string_analysis, images_analysis, output_file_name):
+def save_results(string_analysis, images_analysis, output_file_name,
+                 confusion_matrix):
     """
     Save results in derectory

@@ -104,6 +105,8 @@ def save_results(string_analysis, images_analysis, output_file_name):
     output_text_file = open(output_file_name + 'summary.txt', 'w')
     output_text_file.write(string_analysis)
     output_text_file.close()
+    np.savetxt(output_file_name+"confusion_matrix.csv", confusion_matrix,
+               delimiter=',')

     if images_analysis is not None:
         for image_name in images_analysis.keys():
@@ -339,11 +342,12 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
                                               database_name=dataset_var.get_name(),
                                               nb_cores=nb_cores,
                                               duration=whole_duration)
-    string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
+    string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
+        confusion_matrix = result_analyzer.analyze()
     logging.info("Done:\t Result Analysis for " + cl_type)

     logging.debug("Start:\t Saving preds")
-    save_results(string_analysis, images_analysis, output_file_name)
+    save_results(string_analysis, images_analysis, output_file_name, confusion_matrix)
     logging.debug("Start:\t Saving preds")

     return MultiviewResult(cl_type, classifier_config, metrics_scores,
diff --git a/multiview_platform/mono_multi_view_classifiers/utils/base.py b/multiview_platform/mono_multi_view_classifiers/utils/base.py
index 013ca1b57da87dbdd5e0f900413a7bf680d7cb4b..de2a280b9b0f9fb6a3524df2f50aa2b74b7e992f 100644
--- a/multiview_platform/mono_multi_view_classifiers/utils/base.py
+++ b/multiview_platform/mono_multi_view_classifiers/utils/base.py
@@ -285,7 +285,7 @@ class ResultAnalyser():
             metric_score_string += "\n\t\t- Score on test : {}".format(self.metric_scores[metric][1])
             metric_score_string += "\n\n"
         metric_score_string += "Test set confusion matrix : \n\n"
-        confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
-        formatted_conf = [[label_name]+list(row) for label_name, row in zip(self.class_label_names, confusion_matrix)]
+        self.confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
+        formatted_conf = [[label_name]+list(row) for label_name, row in zip(self.class_label_names, self.confusion_matrix)]
         metric_score_string+=tabulate(formatted_conf, headers= ['']+self.class_label_names, tablefmt='fancy_grid')
         metric_score_string += "\n\n"
@@ -361,7 +361,8 @@
             self.directory, self.base_file_name, self.labels[self.test_indices])
         image_analysis = {}
-        return string_analysis, image_analysis, self.metric_scores, self.class_metric_scores
+        return string_analysis, image_analysis, self.metric_scores, \
+               self.class_metric_scores, self.confusion_matrix



 base_boosting_estimators = [DecisionTreeClassifier(max_depth=1),
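
Note: with this patch, `ResultAnalyser.analyze()` also returns the test-set confusion matrix, and both the monoview and multiview `save_results` write it as `confusion_matrix.csv` next to the other per-classifier outputs. Below is a minimal sketch of reloading that file downstream with numpy; the result prefix used here is a hypothetical example, not a path produced by the platform.

    import numpy as np

    # Hypothetical prefix: in the platform, output_file_name is built from the
    # result directory and the classifier-specific base file name.
    output_file_name = "results/example_classifier-"

    # The patch writes the matrix with np.savetxt in comma-separated form, so
    # np.loadtxt can read it back; rows and columns follow the class label
    # order used for the formatted table in ResultAnalyser.
    conf_mat = np.loadtxt(output_file_name + "confusion_matrix.csv",
                          delimiter=",")
    print(conf_mat)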