Skip to content
Snippets Groups Projects
Commit 540ea254 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

test

parent 142ddc23
No related branches found
No related tags found
No related merge requests found
Pipeline #11595 failed
......@@ -955,7 +955,6 @@ class Summit(BaseExec):
metrics=self.metrics)
return plif
def gen_single_multiview_arg_dictionary(self, classifier_name, arguments, nb_class,
hps_kwargs, views_dictionary=None,):
if classifier_name in arguments:
......@@ -975,10 +974,7 @@ class Summit(BaseExec):
database_name=self.name,
hps_type=self.hps_type,
nb_cores=self.nb_cores,
metrics=self.metrics,
)
metrics=self.metrics,)
def extract_dict(self, classifier_config):
"""Reverse function of get_path_dict"""
......@@ -987,7 +983,6 @@ class Summit(BaseExec):
extracted_dict = self.set_element(extracted_dict, key, value)
return extracted_dict
def set_element(self, dictionary, path, value):
"""Set value in dictionary at the location indicated by path"""
existing_keys = path.split(".")[:-1]
......@@ -1001,7 +996,6 @@ class Summit(BaseExec):
dict_state[path.split(".")[-1]] = value
return dictionary
def get_path_dict(self, multiview_classifier_args):
"""This function is used to generate a dictionary with each key being
the path to the value.
......@@ -1018,7 +1012,6 @@ class Summit(BaseExec):
paths = self.is_dict_in(path_dict)
return path_dict
def is_dict_in(self, dictionary):
"""
Returns True if any of the dictionary value is a dictionary itself.
......@@ -1037,7 +1030,6 @@ class Summit(BaseExec):
paths.append(key)
return paths
def init_kwargs(self, classifiers_names, framework="monoview"):
r"""Used to init kwargs thanks to a function in each monoview classifier package.
......@@ -1068,8 +1060,8 @@ class Summit(BaseExec):
getattr(multiview_classifiers, classifiers_name)
except AttributeError:
raise AttributeError(
classifiers_name + " is not implemented in monoview_classifiers, "
"please specify the name of the file in monoview_classifiers")
classifiers_name + " is not implemented in {}_classifiers, "
"please specify the name of the file in {}_classifiers".format(framework, framework))
if classifiers_name in self.args:
kwargs[classifiers_name] = self.args[classifiers_name]
else:
......
......@@ -55,9 +55,13 @@ def publish_sample_errors(sample_errors, directory, database_name,
nb_classifiers, nb_samples, classifiers_names, \
data_2d, error_on_samples = gen_error_data(sample_errors)
np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",")
np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
delimiter=",")
heat_map_data = pd.DataFrame(index=sample_ids, columns=classifiers_names, data=data_2d)
bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
heat_map_data.to_csv(base_file_name + "2D_plot_data.csv")
bar_plot_data.to_csv(base_file_name + "bar_plot_data.csv")
# np.savetxt(base_file_name + "2D_plot_data.csv", data_2d, delimiter=",")
# np.savetxt(base_file_name + "bar_plot_data.csv", error_on_samples,
# delimiter=",")
plot_2d(data_2d, classifiers_names, nb_classifiers, base_file_name, database_name,
sample_ids=sample_ids, labels=labels, label_names=label_names, test=test)
......@@ -82,9 +86,14 @@ def publish_all_sample_errors(iter_results, directory,
add='t'
else:
add = ""
np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",")
np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples,
delimiter=",")
heat_map_data = pd.DataFrame(index=sample_ids, columns=classifier_names,
data=data)
bar_plot_data = pd.DataFrame(index=sample_ids, data=error_on_samples)
heat_map_data.to_csv(os.path.join(directory, "clf_errors{}.csv".format(add)))
bar_plot_data.to_csv(os.path.join(directory, "sample_errors{}.csv".format(add)))
# np.savetxt(os.path.join(directory, "clf_errors{}.csv".format(add)), data, delimiter=",")
# np.savetxt(os.path.join(directory, "sample_errors{}.csv".format(add)), error_on_samples,
# delimiter=",")
df = pd.DataFrame(index = sample_ids, columns=["err"], data=1-error_on_samples)
df.to_csv(os.path.join(directory, "sample_err_df{}.csv".format(add)))
plot_2d(data, classifier_names, nb_classifiers,
......
......@@ -25,6 +25,7 @@ def get_feature_importances(result, feature_ids=None, view_names=None,):
"""
feature_importances = {}
for classifier_result in result:
print((classifier_result.classifier_name))
if isinstance(classifier_result, MonoviewResult):
if classifier_result.view_name not in feature_importances:
feature_importances[classifier_result.view_name] = pd.DataFrame(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment