Skip to content
Snippets Groups Projects
Commit 540ea254 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

test

parent 142ddc23
Branches
No related tags found
No related merge requests found
Pipeline #11595 failed
...@@ -955,7 +955,6 @@ class Summit(BaseExec): ...@@ -955,7 +955,6 @@ class Summit(BaseExec):
metrics=self.metrics) metrics=self.metrics)
return plif return plif
def gen_single_multiview_arg_dictionary(self, classifier_name, arguments, nb_class, def gen_single_multiview_arg_dictionary(self, classifier_name, arguments, nb_class,
hps_kwargs, views_dictionary=None,): hps_kwargs, views_dictionary=None,):
if classifier_name in arguments: if classifier_name in arguments:
...@@ -975,10 +974,7 @@ class Summit(BaseExec): ...@@ -975,10 +974,7 @@ class Summit(BaseExec):
database_name=self.name, database_name=self.name,
hps_type=self.hps_type, hps_type=self.hps_type,
nb_cores=self.nb_cores, nb_cores=self.nb_cores,
metrics=self.metrics, metrics=self.metrics,)
)
def extract_dict(self, classifier_config): def extract_dict(self, classifier_config):
"""Reverse function of get_path_dict""" """Reverse function of get_path_dict"""
...@@ -987,7 +983,6 @@ class Summit(BaseExec): ...@@ -987,7 +983,6 @@ class Summit(BaseExec):
extracted_dict = self.set_element(extracted_dict, key, value) extracted_dict = self.set_element(extracted_dict, key, value)
return extracted_dict return extracted_dict
def set_element(self, dictionary, path, value): def set_element(self, dictionary, path, value):
"""Set value in dictionary at the location indicated by path""" """Set value in dictionary at the location indicated by path"""
existing_keys = path.split(".")[:-1] existing_keys = path.split(".")[:-1]
...@@ -1001,7 +996,6 @@ class Summit(BaseExec): ...@@ -1001,7 +996,6 @@ class Summit(BaseExec):
dict_state[path.split(".")[-1]] = value dict_state[path.split(".")[-1]] = value
return dictionary return dictionary
def get_path_dict(self, multiview_classifier_args): def get_path_dict(self, multiview_classifier_args):
"""This function is used to generate a dictionary with each key being """This function is used to generate a dictionary with each key being
the path to the value. the path to the value.
...@@ -1018,7 +1012,6 @@ class Summit(BaseExec): ...@@ -1018,7 +1012,6 @@ class Summit(BaseExec):
paths = self.is_dict_in(path_dict) paths = self.is_dict_in(path_dict)
return path_dict return path_dict
def is_dict_in(self, dictionary): def is_dict_in(self, dictionary):
""" """
Returns True if any of the dictionary value is a dictionary itself. Returns True if any of the dictionary value is a dictionary itself.
...@@ -1037,7 +1030,6 @@ class Summit(BaseExec): ...@@ -1037,7 +1030,6 @@ class Summit(BaseExec):
paths.append(key) paths.append(key)
return paths return paths
def init_kwargs(self, classifiers_names, framework="monoview"): def init_kwargs(self, classifiers_names, framework="monoview"):
r"""Used to init kwargs thanks to a function in each monoview classifier package. r"""Used to init kwargs thanks to a function in each monoview classifier package.
...@@ -1068,8 +1060,8 @@ class Summit(BaseExec): ...@@ -1068,8 +1060,8 @@ class Summit(BaseExec):
getattr(multiview_classifiers, classifiers_name) getattr(multiview_classifiers, classifiers_name)
except AttributeError: except AttributeError:
raise AttributeError( raise AttributeError(
classifiers_name + " is not implemented in monoview_classifiers, " classifiers_name + " is not implemented in {}_classifiers, "
"please specify the name of the file in monoview_classifiers") "please specify the name of the file in {}_classifiers".format(framework, framework))
if classifiers_name in self.args: if classifiers_name in self.args:
kwargs[classifiers_name] = self.args[classifiers_name] kwargs[classifiers_name] = self.args[classifiers_name]
else: else:
......
...@@ -55,9 +55,13 @@ def publish_sample_errors(sample_errors, directory, database_name, ...@@ -55,9 +55,13 @@ def publish_sample_errors(sample_errors, directory, database_name,
nb_classifiers, nb_samples, classifiers_names, \ nb_classifiers, nb_samples, classifiers_names, \
data_2d, error_on_samples = gen_error_data(sample_errors) data_2d, error_on_samples = gen_error_data(sample_errors)
np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",") heat_map_data = pd.DataFrame(index=sample_ids, columns=classifiers_names, data=data_2d)
np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples, bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
delimiter=",") heat_map_data.to_csv(base_file_name + "2D_plot_data.csv")
bar_plot_data.to_csv(base_file_name + "bar_plot_data.csv")
# np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",")
# np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
# delimiter=",")
plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name, plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name,
sample_ids=sample_ids, labels=labels, label_names=label_names, test=test) sample_ids=sample_ids, labels=labels, label_names=label_names, test=test)
...@@ -82,9 +86,14 @@ def publish_all_sample_errors(iter_results, directory, ...@@ -82,9 +86,14 @@ def publish_all_sample_errors(iter_results, directory,
add='t' add='t'
else: else:
add = "" add = ""
np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",") heat_map_data = pd.DataFrame(index=sample_ids, columns=classifier_names,
np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples, data=data)
delimiter=",") bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
heat_map_data.to_csv(os.path.join(directory, "clf_errors{}.csv".format(add)))
bar_plot_data.to_csv(os.path.join(directory, "sample_errors{}.csv".format(add)))
# np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",")
# np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples,
# delimiter=",")
df = pd.DataFrame(index = sample_ids, columns=["err"], data=1-error_on_samples) df = pd.DataFrame(index = sample_ids, columns=["err"], data=1-error_on_samples)
df.to_csv(os.path.join(directory, "sample_err_df{}.csv".format(add))) df.to_csv(os.path.join(directory, "sample_err_df{}.csv".format(add)))
plot_2d(data, classifier_names, nb_classifiers, plot_2d(data, classifier_names, nb_classifiers,
......
...@@ -25,6 +25,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,): ...@@ -25,6 +25,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
""" """
feature_importances = {} feature_importances = {}
for classifier_result in result: for classifier_result in result:
print((classifier_result.classifier_name))
if isinstance(classifier_result, MonoviewResult): if isinstance(classifier_result, MonoviewResult):
if classifier_result.view_name not in feature_importances: if classifier_result.view_name not in feature_importances:
feature_importances[classifier_result.view_name] = pd.DataFrame( feature_importances[classifier_result.view_name] = pd.DataFrame(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment