Skip to content
Snippets Groups Projects
Commit 51317038 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Saving confusion matrinx as csv

parent 0f683f7f
No related branches found
No related tags found
No related merge requests found
......@@ -150,12 +150,13 @@ def exec_monoview(directory, X, Y, database_name, labels_names, classification_i
database_name=database_name,
nb_cores=nb_cores,
duration=whole_duration)
string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
confusion_matrix = result_analyzer.analyze()
logging.debug("Done:\t Getting results")
logging.debug("Start:\t Saving preds")
save_results(string_analysis, output_file_name, full_pred, train_pred,
y_train, images_analysis, y_test)
y_train, images_analysis, y_test, confusion_matrix)
logging.info("Done:\t Saving results")
view_index = args["view_index"]
......@@ -222,11 +223,13 @@ def get_hyper_params(classifier_module, search_method, classifier_module_name,
def save_results(string_analysis, output_file_name, full_labels_pred,
y_train_pred,
y_train, images_analysis, y_test):
y_train, images_analysis, y_test, confusion_matrix):
logging.info(string_analysis)
output_text_file = open(output_file_name + 'summary.txt', 'w')
output_text_file.write(string_analysis)
output_text_file.close()
np.savetxt(output_file_name+"confusion_matrix.csv", confusion_matrix,
delimiter=', ')
np.savetxt(output_file_name + "full_pred.csv",
full_labels_pred.astype(np.int16), delimiter=",")
np.savetxt(output_file_name + "train_pred.csv",
......
......@@ -71,7 +71,8 @@ def init_constants(kwargs, classification_indices, metrics,
directory, base_file_name
def save_results(string_analysis, images_analysis, output_file_name):
def save_results(string_analysis, images_analysis, output_file_name,
confusion_matrix):
"""
Save results in derectory
......@@ -104,6 +105,8 @@ def save_results(string_analysis, images_analysis, output_file_name):
output_text_file = open(output_file_name + 'summary.txt', 'w')
output_text_file.write(string_analysis)
output_text_file.close()
np.savetxt(output_file_name+"confusion_matrix.csv", confusion_matrix,
delimiter=',')
if images_analysis is not None:
for image_name in images_analysis.keys():
......@@ -339,11 +342,12 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
database_name=dataset_var.get_name(),
nb_cores=nb_cores,
duration=whole_duration)
string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
confusion_matrix = result_analyzer.analyze()
logging.info("Done:\t Result Analysis for " + cl_type)
logging.debug("Start:\t Saving preds")
save_results(string_analysis, images_analysis, output_file_name)
save_results(string_analysis, images_analysis, output_file_name, confusion_matrix)
logging.debug("Start:\t Saving preds")
return MultiviewResult(cl_type, classifier_config, metrics_scores,
......
......@@ -285,7 +285,7 @@ class ResultAnalyser():
metric_score_string += "\n\t\t- Score on test : {}".format(self.metric_scores[metric][1])
metric_score_string += "\n\n"
metric_score_string += "Test set confusion matrix : \n\n"
confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
self.confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
formatted_conf = [[label_name]+list(row) for label_name, row in zip(self.class_label_names, confusion_matrix)]
metric_score_string+=tabulate(formatted_conf, headers= ['']+self.class_label_names, tablefmt='fancy_grid')
metric_score_string += "\n\n"
......@@ -361,7 +361,8 @@ class ResultAnalyser():
self.directory, self.base_file_name,
self.labels[self.test_indices])
image_analysis = {}
return string_analysis, image_analysis, self.metric_scores, self.class_metric_scores
return string_analysis, image_analysis, self.metric_scores, \
self.class_metric_scores, self.confusion_matrix
base_boosting_estimators = [DecisionTreeClassifier(max_depth=1),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment