diff --git a/summit/multiview_platform/exec_classif.py b/summit/multiview_platform/exec_classif.py index 73b87d7786e0d034896a9046e4c3803e675cff53..09f708aa7593fb13e9db91ba2c0e88e61c82eca4 100644 --- a/summit/multiview_platform/exec_classif.py +++ b/summit/multiview_platform/exec_classif.py @@ -955,7 +955,6 @@ class Summit(BaseExec): metrics=self.metrics) return plif - def gen_single_multiview_arg_dictionary(self, classifier_name, arguments, nb_class, hps_kwargs, views_dictionary=None,): if classifier_name in arguments: @@ -975,10 +974,7 @@ class Summit(BaseExec): database_name=self.name, hps_type=self.hps_type, nb_cores=self.nb_cores, - metrics=self.metrics, - - ) - + metrics=self.metrics,) def extract_dict(self, classifier_config): """Reverse function of get_path_dict""" @@ -987,7 +983,6 @@ class Summit(BaseExec): extracted_dict = self.set_element(extracted_dict, key, value) return extracted_dict - def set_element(self, dictionary, path, value): """Set value in dictionary at the location indicated by path""" existing_keys = path.split(".")[:-1] @@ -1001,7 +996,6 @@ class Summit(BaseExec): dict_state[path.split(".")[-1]] = value return dictionary - def get_path_dict(self, multiview_classifier_args): """This function is used to generate a dictionary with each key being the path to the value. @@ -1018,7 +1012,6 @@ class Summit(BaseExec): paths = self.is_dict_in(path_dict) return path_dict - def is_dict_in(self, dictionary): """ Returns True if any of the dictionary value is a dictionary itself. @@ -1037,7 +1030,6 @@ class Summit(BaseExec): paths.append(key) return paths - def init_kwargs(self, classifiers_names, framework="monoview"): r"""Used to init kwargs thanks to a function in each monoview classifier package. @@ -1068,8 +1060,8 @@ class Summit(BaseExec): getattr(multiview_classifiers, classifiers_name) except AttributeError: raise AttributeError( - classifiers_name + " is not implemented in monoview_classifiers, " - "please specify the name of the file in monoview_classifiers") + classifiers_name + " is not implemented in {}_classifiers, " + "please specify the name of the file in {}_classifiers".format(framework, framework)) if classifiers_name in self.args: kwargs[classifiers_name] = self.args[classifiers_name] else: diff --git a/summit/multiview_platform/result_analysis/error_analysis.py b/summit/multiview_platform/result_analysis/error_analysis.py index 672404462e214d85b1ded0ea75b377cdbf6fffe8..96dbc9bba0cee7b831db65fc6c8ac684255d6424 100644 --- a/summit/multiview_platform/result_analysis/error_analysis.py +++ b/summit/multiview_platform/result_analysis/error_analysis.py @@ -55,9 +55,13 @@ def publish_sample_errors(sample_errors, directory, database_name, nb_classifiers, nb_samples, classifiers_names, \ data_2d, error_on_samples = gen_error_data(sample_errors) - np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",") - np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples, - delimiter=",") + heat_map_data = pd.DataFrame(index=sample_ids, columns=classifiers_names, data=data_2d) + bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples) + heat_map_data.to_csv(base_file_name + "2D_plot_data.csv") + bar_plot_data.to_csv(base_file_name + "bar_plot_data.csv") + # np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",") + # np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples, + # delimiter=",") plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name, sample_ids=sample_ids, labels=labels, label_names=label_names, test=test) @@ -82,9 +86,14 @@ def publish_all_sample_errors(iter_results, directory, add='t' else: add = "" - np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",") - np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples, - delimiter=",") + heat_map_data = pd.DataFrame(index=sample_ids, columns=classifier_names, + data=data) + bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples) + heat_map_data.to_csv(os.path.join(directory, "clf_errors{}.csv".format(add))) + bar_plot_data.to_csv(os.path.join(directory, "sample_errors{}.csv".format(add))) + # np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",") + # np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples, + # delimiter=",") df = pd.DataFrame(index = sample_ids, columns=["err"], data=1-error_on_samples) df.to_csv(os.path.join(directory, "sample_err_df{}.csv".format(add))) plot_2d(data, classifier_names, nb_classifiers, diff --git a/summit/multiview_platform/result_analysis/feature_importances.py b/summit/multiview_platform/result_analysis/feature_importances.py index ec233dc22b50e0e50ab5cfb18752852b7577705d..ad92c6cdf8e0314a4677b3d438370d9b3c809db6 100644 --- a/summit/multiview_platform/result_analysis/feature_importances.py +++ b/summit/multiview_platform/result_analysis/feature_importances.py @@ -25,6 +25,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): """ feature_importances = {} for classifier_result in result: + print((classifier_result.classifier_name)) if isinstance(classifier_result, MonoviewResult): if classifier_result.view_name not in feature_importances: feature_importances[classifier_result.view_name] = pd.DataFrame(