Commit 9888c471 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Fresh version

parent f52062e9
__pycache__
demo/*.hdf5
demo/*.html
multiview_generator.egg-info
\ No newline at end of file
multiview_generator.egg-info
demo/tutorials/.ipy*
demo/tutorials/supplementray_material/tuto/
\ No newline at end of file
from . import multiview_generator
from . import demo
......@@ -49,6 +49,7 @@ def make_fig(conf, confusion_output, n_views, n_classes, generator):
{'type': 'scatter3d'}, ]])
row = 1
col = 1
show_legend = True
for view_index in range(n_views):
for lab_index in range(n_classes):
concerned_examples = np.where(generator.y == lab_index)[0]
......@@ -59,11 +60,14 @@ def make_fig(conf, confusion_output, n_views, n_classes, generator):
z=generator.view_data[view_index][concerned_examples, 2],
text=[generator.example_ids[ind] for ind in concerned_examples],
hoverinfo='text',
legendgroup="Class {}".format(lab_index),
mode='markers', marker=dict(
size=1, # set color to an array/list of desired values
color=DEFAULT_PLOTLY_COLORS[lab_index],
opacity=0.8
), name="Class {}".format(lab_index)), row=row, col=col)
), name="Class {}".format(lab_index), showlegend=show_legend),
row=row, col=col)
show_legend = False
# fig.update_layout(
# scene=dict(
# xaxis=dict(nticks=4, range=[low_range, high_range], ),
......
n_views: 4
n_classes: 3
confusion_matrix:
error_matrix:
- [0.4, 0.4, 0.4, 0.4]
- [0.55, 0.4, 0.4, 0.4]
- [0.4, 0.5, 0.52, 0.55]
# - [0.4, 0.5, 0.5, 0.4]
# - [0.4, 0.4, 0.4, 0.4]
# - [0.4, 0.4, 0.4, 0.4]
# - [0.4, 0.4, 0.4, 0.4]
# - [0.4, 0.4, 0.4, 0.4]
n_samples: 2000
n_features: 3
n_informative: 3
class_seps: 10
class_weights: [0.125, 0.125, 0.125,]# 0.125, 0.125, 0.125, 0.125, 0.125,]
mutual_error: 0.2
redundancy: 0.1
complementarity: 0.35
name: "doc_summit"
mutual_error: [0.4, 0.4,0.4]
redundancy: [0.5,0.4,0.4]
complementarity: [0.1, 0.05,0.05]
name: "demo"
sub_problem_type: ["base", "base", "base", "gaussian"]
......@@ -4,10 +4,11 @@ from classify_generated import gen_folds, make_fig, test_dataset
# Demo driver: build a generator from a YAML configuration, generate the
# multi-view dataset, then evaluate it on folds.
n_views = 4
n_classes = 3
# NOTE(review): this first generator is immediately replaced by the next
# assignment; it looks like the pre-change line of the diff - confirm.
gene = MultiViewSubProblemsGenerator(config_file="config_generator.yml")
gene = MultiViewSubProblemsGenerator(config_file="config_demo.yml")
# (n_classes, n_views) array filled with 0.4.
conf = np.ones((n_classes, n_views))*0.4
gene.generate_multi_view_dataset()
# Presumably writes the generated dataset to an HDF5 file - TODO confirm
# against the generator class.
gene.to_hdf5_mc()
print(gene.gen_report())
# Folds are built with a fixed random_state (42).
folds = gen_folds(random_state=42, generator=gene)
output_confusion = test_dataset(folds, n_views, n_classes, gene)
......
%% Cell type:code id: tags:
``` python
import numpy as np
from multiview_generator.multiple_sub_problems import MultiViewSubProblemsGenerator
# Problem dimensions used throughout this notebook.
n_views = 3
n_classes = 3

# Per-class settings, stored as (n_classes, 1) column vectors.
complementarity = np.full((n_classes, 1), 0.3)
complementarity_level = np.full((n_classes, 1), 0.5)

# 100 examples per class; class k starts with indices 100*k .. 100*k + 99.
n_examples_per_class = np.full(n_classes, 100)
available_init_indices = [list(range(100 * class_ind, 100 * (class_ind + 1)))
                          for class_ind in range(n_classes)]

error_matrix = np.zeros((12, 12))

# One slot per class, seeded with the integers 0 .. n_classes - 1.
complementarity_examples = list(range(n_classes))
good_views_indices = list(range(n_classes))
bad_views_indices = list(range(n_classes))

rs = np.random.RandomState(42)
# One fixed-width (100-byte) identifier per example.
example_ids = np.zeros(sum(n_examples_per_class), dtype="S100")
print(available_init_indices)
```
%% Output
[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199], [200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299]]
%% Cell type:code id: tags:
``` python
def _remove_available(available_indices, to_remove, class_index):
"""
Removes indices from the available ones array
"""
available_indices[class_index] = [ind
for ind
in available_indices[class_index]
if ind not in to_remove]
return available_indices
def _update_example_indices(target, target_name, class_ind):
    """
    Label every example listed in ``target`` in the global ``example_ids``.

    The example found at position ``position`` of ``target`` receives the
    identifier ``<target_name>_<position>_<class_ind>``.
    """
    suffix_template = "_{}_{}"
    for position, example_index in enumerate(target):
        suffix = suffix_template.format(position, class_ind)
        example_ids[example_index] = target_name + suffix
```
%% Output
[[array([0, 2]), array([0, 2]), array([0, 2]), array([0, 1]), array([1, 2]), array([0, 1]), array([0, 2]), array([0, 1]), array([0, 2]), array([1, 2]), array([0, 2]), array([0, 1]), array([0, 2]), array([0, 2]), array([1, 2]), array([0, 2]), array([0, 1]), array([0, 2]), array([0, 1]), array([0, 2]), array([0, 2]), array([1, 2]), array([0, 2]), array([1, 2]), array([0, 2]), array([0, 1]), array([0, 1]), array([0, 2]), array([0, 2]), array([0, 2])], [array([0, 2]), array([1, 2]), array([0, 2]), array([0, 2]), array([0, 1]), array([1, 2]), array([0, 1]), array([1, 2]), array([1, 2]), array([0, 1]), array([0, 1]), array([0, 1]), array([1, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([0, 2]), array([1, 2]), array([1, 2]), array([0, 1]), array([0, 2]), array([1, 2]), array([0, 1]), array([0, 1]), array([0, 2]), array([0, 1]), array([1, 2]), array([1, 2]), array([1, 2]), array([0, 2])], [array([1, 2]), array([0, 2]), array([0, 1]), array([0, 1]), array([1, 2]), array([0, 1]), array([1, 2]), array([0, 1]), array([1, 2]), array([1, 2]), array([1, 2]), array([0, 2]), array([0, 1]), array([0, 2]), array([0, 1]), array([0, 1]), array([0, 2]), array([0, 1]), array([1, 2]), array([0, 2]), array([0, 2]), array([0, 1]), array([0, 1]), array([0, 1]), array([0, 1]), array([0, 2]), array([0, 2]), array([1, 2]), array([0, 2]), array([1, 2])]]
%% Cell type:code id: tags:
``` python
# Per-class view count: the class' complementarity level times the number
# of views, truncated to an integer.
# ``.item()`` extracts the Python scalar explicitly: calling ``int()`` on a
# one-element array that is not 0-d is deprecated since NumPy 1.25.
n_bad = [int(complementarity_level[class_index].item() * n_views)
         for class_index in range(n_classes)]
n_bad
```
%% Output
[1, 1, 1]
%% Cell type:markdown id: tags:
Complementarity is defined per class, which means, for example, that the samples of class i can be highly complementary.
To check whether this setting is compatible with the error matrix,
we now verify that enough indices remain available, i.e. indices that are not already used for redundancy or mutual error.
%% Cell type:code id: tags:
``` python
# Per-class feasibility check: True marks a class that needs more
# complementary examples than it still has available indices.
# ``complementarity`` is an (n_classes, 1) column, so it is flattened first
# to get an element-wise product with ``n_examples_per_class``. Multiplying
# the column directly broadcasts to a matrix whose first row applies
# class 0's ratio to every class.
(np.ravel(complementarity) * n_examples_per_class > np.array(
    [len(inds) for inds in available_init_indices]))
```
%% Output
array([False, False, False])
%% Cell type:code id: tags:
``` python
# For every class: draw its complementary examples, label them, pick the
# view indices stored in ``good_views_indices`` for each of them, derive
# the remaining views as ``bad_views_indices``, and take the drawn examples
# out of the available pool.
# The loop variable is deliberately not named ``complementarity``, so the
# per-class array is not rebound and the cell stays re-runnable.
for class_index, class_complementarity in enumerate(complementarity):
    # ``.item()`` extracts Python scalars explicitly: calling ``int()`` on a
    # one-element array that is not 0-d is deprecated since NumPy 1.25.
    n_selected_views = int(
        complementarity_level[class_index].item() * n_views)
    n_complementary = int(n_examples_per_class[class_index]
                          * class_complementarity.item())
    complementarity_examples[class_index] = rs.choice(
        available_init_indices[class_index],
        size=n_complementary,
        replace=False)
    _update_example_indices(
        complementarity_examples[class_index],
        'Complementary', class_index)
    # The size is this class' own view count (same formula as ``n_bad``),
    # not the whole per-class list, which would be read as an output shape.
    good_views_indices[class_index] = [
        rs.choice(np.arange(n_views),
                  size=n_selected_views,
                  replace=False)
        for _ in complementarity_examples[class_index]]
    # Every view that was not drawn above, in increasing order.
    bad_views_indices[class_index] = [
        np.array([view_index
                  for view_index in range(n_views)
                  if view_index not in good_views])
        for good_views in good_views_indices[class_index]]
    _remove_available(available_init_indices,
                      complementarity_examples[class_index],
                      class_index)
print(bad_views_indices)
```