Commit cf04f1ff authored by Baptiste Bauvin

Merge branch 'develop' into private_algos

parents 3aeb6170 51317038
@@ -150,12 +150,13 @@ def exec_monoview(directory, X, Y, database_name, labels_names, classification_i
                                       database_name=database_name,
                                       nb_cores=nb_cores,
                                       duration=whole_duration)
-    string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
+    string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
+        confusion_matrix = result_analyzer.analyze()
     logging.debug("Done:\t Getting results")
     logging.debug("Start:\t Saving preds")
     save_results(string_analysis, output_file_name, full_pred, train_pred,
-                 y_train, images_analysis, y_test)
+                 y_train, images_analysis, y_test, confusion_matrix)
     logging.info("Done:\t Saving results")

     view_index = args["view_index"]
@@ -222,11 +223,13 @@ def get_hyper_params(classifier_module, search_method, classifier_module_name,

 def save_results(string_analysis, output_file_name, full_labels_pred,
                  y_train_pred,
-                 y_train, images_analysis, y_test):
+                 y_train, images_analysis, y_test, confusion_matrix):
     logging.info(string_analysis)
     output_text_file = open(output_file_name + 'summary.txt', 'w')
     output_text_file.write(string_analysis)
     output_text_file.close()
+    np.savetxt(output_file_name + "confusion_matrix.csv", confusion_matrix,
+               delimiter=', ')
     np.savetxt(output_file_name + "full_pred.csv",
                full_labels_pred.astype(np.int16), delimiter=",")
     np.savetxt(output_file_name + "train_pred.csv",
...
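The monoview change threads the confusion matrix returned by result_analyzer.analyze() through to save_results(), which dumps it next to the prediction CSVs with np.savetxt. An illustrative sketch of that dump, not the commit's code (file name and values are made up): np.savetxt writes one matrix row per line, and its default format is scientific notation ('%.18e'), so an explicit fmt='%d' would keep the integer counts readable; the commit itself relies on the default.

# Illustrative sketch only: dumping a confusion matrix to CSV the way
# save_results() now does. The file name and counts are hypothetical,
# and fmt='%d' is a readability tweak the commit does not use.
import numpy as np

confusion_matrix = np.array([[48, 2],
                             [5, 45]])  # toy 2x2 class counts
np.savetxt("confusion_matrix.csv", confusion_matrix,
           delimiter=',', fmt='%d')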
@@ -71,7 +71,8 @@ def init_constants(kwargs, classification_indices, metrics,
     directory, base_file_name


-def save_results(string_analysis, images_analysis, output_file_name):
+def save_results(string_analysis, images_analysis, output_file_name,
+                 confusion_matrix):
     """
     Save results in directory
@@ -104,6 +105,8 @@ def save_results(string_analysis, images_analysis, output_file_name):
     output_text_file = open(output_file_name + 'summary.txt', 'w')
     output_text_file.write(string_analysis)
     output_text_file.close()
+    np.savetxt(output_file_name + "confusion_matrix.csv", confusion_matrix,
+               delimiter=',')
     if images_analysis is not None:
         for image_name in images_analysis.keys():
@@ -339,11 +342,12 @@ def exec_multiview(directory, dataset_var, name, classification_indices,
                                         database_name=dataset_var.get_name(),
                                         nb_cores=nb_cores,
                                         duration=whole_duration)
-    string_analysis, images_analysis, metrics_scores, class_metrics_scores = result_analyzer.analyze()
+    string_analysis, images_analysis, metrics_scores, class_metrics_scores, \
+        confusion_matrix = result_analyzer.analyze()
     logging.info("Done:\t Result Analysis for " + cl_type)
     logging.debug("Start:\t Saving preds")
-    save_results(string_analysis, images_analysis, output_file_name)
+    save_results(string_analysis, images_analysis, output_file_name, confusion_matrix)
     logging.debug("Done:\t Saving preds")
     return MultiviewResult(cl_type, classifier_config, metrics_scores,
...
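exec_multiview mirrors the monoview flow, so both pipelines now leave a confusion_matrix.csv beside summary.txt. A quick, illustrative round-trip check of that artifact (the path is hypothetical, not from the commit):

# Illustrative only: load the dumped matrix back for a sanity check.
import numpy as np

loaded = np.loadtxt("confusion_matrix.csv", delimiter=',')
print(loaded.astype(int))  # loadtxt returns float64 by default; cast for display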
@@ -285,7 +285,7 @@ class ResultAnalyser():
         metric_score_string += "\n\t\t- Score on test : {}".format(self.metric_scores[metric][1])
         metric_score_string += "\n\n"
         metric_score_string += "Test set confusion matrix : \n\n"
-        confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
-        formatted_conf = [[label_name] + list(row) for label_name, row in zip(self.class_label_names, confusion_matrix)]
+        self.confusion_matrix = confusion(y_true=self.labels[self.test_indices], y_pred=self.pred[self.test_indices])
+        formatted_conf = [[label_name] + list(row) for label_name, row in zip(self.class_label_names, self.confusion_matrix)]
         metric_score_string += tabulate(formatted_conf, headers=[''] + self.class_label_names, tablefmt='fancy_grid')
         metric_score_string += "\n\n"
@@ -361,7 +361,8 @@ class ResultAnalyser():
                 self.directory, self.base_file_name,
                 self.labels[self.test_indices])
            image_analysis = {}
-        return string_analysis, image_analysis, self.metric_scores, self.class_metric_scores
+        return string_analysis, image_analysis, self.metric_scores, \
+            self.class_metric_scores, self.confusion_matrix
 base_boosting_estimators = [DecisionTreeClassifier(max_depth=1),
...
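In ResultAnalyser.analyze(), the matrix is now kept on self so callers can save it, while the fancy_grid rendering with tabulate is unchanged. A standalone sketch of that rendering, with sklearn's confusion_matrix standing in for the confusion helper used in the diff, and made-up labels and data:

# Illustrative: mimics the formatted_conf / tabulate lines above.
from sklearn.metrics import confusion_matrix
from tabulate import tabulate

y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]
class_label_names = ["negative", "positive"]

conf = confusion_matrix(y_true=y_true, y_pred=y_pred)
# One table row per true class, prefixed with its label.
formatted_conf = [[name] + list(row)
                  for name, row in zip(class_label_names, conf)]
print(tabulate(formatted_conf,
               headers=[''] + class_label_names,
               tablefmt='fancy_grid'))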