From ce7290d1827f09272ef8d6174f11900b729142ad Mon Sep 17 00:00:00 2001
From: Baptiste Bauvin <baptiste.bauvin@lis-lab.fr>
Date: Tue, 17 Dec 2019 09:29:50 -0500
Subject: [PATCH] Doc

---
 config_files/config_test.yml                  |  2 +-
 docs/source/tutorials/example4.rst            |  2 +-
 .../result_analysis.py                        | 45 ++++++++++---------
 3 files changed, 27 insertions(+), 22 deletions(-)

diff --git a/config_files/config_test.yml b/config_files/config_test.yml
index 5633883c..1914f94e 100644
--- a/config_files/config_test.yml
+++ b/config_files/config_test.yml
@@ -23,7 +23,7 @@ Classification:
   nb_class: 2
   classes:
   type: ["monoview"]
-  algos_monoview: ["all", ]
+  algos_monoview: ["adaboost",]
   algos_multiview: ["weighted_linear_early_fusion"]
   stats_iter: 2
   metrics: ["accuracy_score", "f1_score"]
diff --git a/docs/source/tutorials/example4.rst b/docs/source/tutorials/example4.rst
index 6fdd70a4..a6ba2f79 100644
--- a/docs/source/tutorials/example4.rst
+++ b/docs/source/tutorials/example4.rst
@@ -118,7 +118,7 @@ In order to be able to analyze the results with more clarity, one can add the ex
 
 Let's suppose that the objects we are trying to classify between 'Animal' and 'Object' are all bears, cars, planes, and birds. And that one has a ``.csv`` file with an ID for each of them (:python:`"bear_112", "plane_452", "bird_785", "car_369", ...` for example)
 
-Then as long as the IDs order conresspond to the example order in the lines of the previous matrices, to add the IDs in the hdf5 file, just add :
+Then as long as the IDs order corresponds to the example order in the lines of the previous matrices, to add the IDs in the hdf5 file, just add :
 
 .. code-block:: python
 
diff --git a/multiview_platform/mono_multi_view_classifiers/result_analysis.py b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
index 53178701..5cb0318d 100644
--- a/multiview_platform/mono_multi_view_classifiers/result_analysis.py
+++ b/multiview_platform/mono_multi_view_classifiers/result_analysis.py
@@ -136,7 +136,7 @@ def plot_metric_scores(train_scores, test_scores, names, nb_results, metric_name
 
 
 def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
-            fileName, minSize=10,
+            fileName, minSize=10, labels=None,
             width_denominator=2.0, height_denominator=20.0, stats_iter=1,
             use_plotly=True, example_ids=None):
     r"""Used to generate a 2D plot of the errors.
@@ -174,8 +174,8 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
                      aspect='auto')
     plt.title('Errors depending on the classifier')
     ticks = np.arange(0, nbClassifiers, 1)
-    labels = classifiers_names
-    plt.xticks(ticks, labels, rotation="vertical")
+    tick_labels = classifiers_names
+    plt.xticks(ticks, tick_labels, rotation="vertical")
     cbar = fig.colorbar(cax, ticks=[-100 * stats_iter / 2, 0, stats_iter])
     cbar.ax.set_yticklabels(['Unseen', 'Always Wrong', 'Always Right'])
 
@@ -183,22 +183,27 @@ def plot_2d(data, classifiers_names, nbClassifiers, nbExamples,
     plt.close()
     ### The following part is used to generate an interactive graph.
     if use_plotly:
-        hover_text = [["Failed "+ str(stats_iter-data[i,j])+" time(s)"
+        label_index_list = [np.where(labels==i)[0] for i in np.unique(labels)]
+        print(label_index_list)
+        hover_text = [[example_ids[i] + " failed "+ str(stats_iter-data[i,j])+" time(s)"
                        for j in range(data.shape[1])]
                       for i in range(data.shape[0]) ]
-        fig = plotly.graph_objs.Figure(data=plotly.graph_objs.Heatmap(
-            x=list(classifiers_names),
-            y=[_ for _ in example_ids],
-            z=data,
-            text=hover_text,
-            hoverinfo=["y", "x", "text"],
-            colorscale="Greys",
-            colorbar=dict(tickvals=[0, stats_iter],
-                          ticktext=["Always Wrong", "Always Right"]),
-            reversescale=True))
-        fig.update_layout(
-            xaxis={"showgrid": False, "showticklabels": False, "ticks": ''},
-            yaxis={"showgrid": False, "showticklabels": False, "ticks": ''})
+        fig = plotly.subplots.make_subplots(rows=len(label_index_list), cols=1)
+        for row_index, label_index in enumerate(label_index_list):
+            fig.add_trace(plotly.graph_objs.Heatmap(
+                x=list(classifiers_names),
+                y=[example_ids[label_ind] for label_ind in label_index],
+                z=data[label_index, :],
+                text=hover_text,
+                hoverinfo=["y", "x", "text"],
+                colorscale="Greys",
+                colorbar=dict(tickvals=[0, stats_iter],
+                              ticktext=["Always Wrong", "Always Right"]),
+                reversescale=True), row=row_index+1, col=1)
+            fig.update_yaxes(title_text="Label "+str(row_index), showticklabels=False, ticks='', row=row_index+1, col=1)
+            fig.update_xaxes(showticklabels=False, row=row_index+1, col=1)
+
+        fig.update_xaxes(showticklabels=True, row=len(label_index_list), col=1)
         plotly.offline.plot(fig, filename=fileName + "error_analysis_2D.html", auto_open=False)
         del fig
 
@@ -537,7 +542,7 @@ def gen_error_data(example_errors):
     return nb_classifiers, nb_examples, classifiers_names, data_2d, error_on_examples
 
 
-def publishExampleErrors(example_errors, directory, databaseName, labels_names, example_ids):
+def publishExampleErrors(example_errors, directory, databaseName, labels_names, example_ids, labels):
     logging.debug("Start:\t Biclass Label analysis figure generation")
 
     base_file_name = directory + time.strftime(
@@ -552,7 +557,7 @@ def publishExampleErrors(example_errors, directory, databaseName, labels_names,
                delimiter=",")
 
     plot_2d(data_2d, classifiers_names, nb_classifiers, nb_examples,
-            base_file_name, example_ids=example_ids)
+            base_file_name, example_ids=example_ids, labels=labels)
 
     plot_errors_bar(error_on_examples, nb_classifiers, nb_examples,
                     base_file_name)
@@ -689,7 +694,7 @@ def analyze_biclass(results, benchmark_argument_dictionaries, stats_iter, metric
         results = publishMetricsGraphs(metrics_scores, directory, database_name,
                              labels_names)
         publishExampleErrors(example_errors, directory, database_name,
-                             labels_names, example_ids)
+                             labels_names, example_ids, arguments["labels"])
         publish_feature_importances(feature_importances, directory, database_name, labels_names)
         if not str(classifierPositive) + str(classifierNegative) in biclass_results:
             biclass_results[str(classifierPositive) + str(classifierNegative)] = {}
-- 
GitLab