"""Unit tests for ``multiview_utils`` (test_multiview_utils.py)."""

import os
import unittest

import h5py
import numpy as np
from sklearn.model_selection import StratifiedKFold

from summit.tests.utils import rm_tmp, tmp_path, test_dataset

from summit.multiview_platform.multiview import multiview_utils


class FakeMVClassif(multiview_utils.BaseMultiviewClassifier):
    """Minimal classifier stub used to exercise ``BaseMultiviewClassifier``.

    Parameters
    ----------
    mc : bool, default True
        When false, ``fit`` raises ``ValueError``; when true, ``fit`` is a
        no-op. The tests use this flag to drive ``accepts_multi_class``.
    """

    def __init__(self, mc=True):
        # NOTE(review): the parent ``__init__`` is not called here, exactly as
        # in the original test double -- confirm the base class needs no setup.
        self.mc = mc

    def fit(self, X, y):
        """Raise ``ValueError`` if ``self.mc`` is false, otherwise do nothing."""
        if not self.mc:
            raise ValueError


class TestBaseMultiviewClassifier(unittest.TestCase):
    """Checks ``accepts_multi_class`` through the ``FakeMVClassif`` stub."""

    @classmethod
    def setUpClass(cls):
        # Start from a clean temporary directory for the whole test class.
        rm_tmp()
        os.mkdir(tmp_path)

    @classmethod
    def tearDownClass(cls):
        rm_tmp()

    def test_accepts_multiclass(self):
        """The stub reports multiclass support according to its ``mc`` flag."""
        rs = np.random.RandomState(42)

        accepts = FakeMVClassif().accepts_multi_class(rs)
        self.assertEqual(accepts, True)

        accepts = FakeMVClassif(mc=False).accepts_multi_class(rs)
        self.assertEqual(accepts, False)

        # With these explicit sizes the call is expected to raise instead of
        # returning False.
        with self.assertRaises(ValueError):
            FakeMVClassif(mc=False).accepts_multi_class(
                rs, n_samples=2, n_classes=3)


class TestConfigGenerator(unittest.TestCase):
    """Checks the configuration sampled by ``multiview_utils.ConfigGenerator``."""

    @classmethod
    def setUpClass(cls):
        # Fixed seed: the expected sample below depends on it.
        cls.rs = np.random.RandomState(42)

    def test_simple(self):
        """Sampling with a repeated classifier name yields a single entry."""
        cfg_gen = multiview_utils.ConfigGenerator(
            ["decision_tree", "decision_tree"])
        sample = cfg_gen.rvs(self.rs)
        expected = {
            'decision_tree': {
                'criterion': 'entropy',
                'max_depth': 103,
                'splitter': 'best',
            },
        }
        self.assertEqual(sample, expected)


class TestFunctions(unittest.TestCase):
    """Checks the module-level helper functions of ``multiview_utils``."""

    @classmethod
    def setUpClass(cls):
        # Start from a clean temporary directory for the whole test class.
        rm_tmp()
        os.mkdir(tmp_path)
        cls.rs = np.random.RandomState(42)

    @classmethod
    def tearDownClass(cls):
        rm_tmp()

    def test_get_available_monoview_classifiers(self):
        """The classifier list matches, with and without ``need_probas``."""
        avail = multiview_utils.get_available_monoview_classifiers()
        self.assertEqual(avail, ['adaboost',
                                 'decision_tree',
                                 'gradient_boosting',
                                 'knn',
                                 'lasso',
                                 'random_forest',
                                 'random_scm',
                                 'scm',
                                 'sgd',
                                 'svm_linear',
                                 'svm_poly',
                                 'svm_rbf'])

        # With need_probas=True the expected list no longer contains
        # 'lasso' and 'sgd'.
        avail = multiview_utils.get_available_monoview_classifiers(
            need_probas=True)
        self.assertEqual(avail, ['adaboost',
                                 'decision_tree',
                                 'gradient_boosting',
                                 'knn',
                                 'random_forest',
                                 'random_scm',
                                 'scm',
                                 'svm_linear',
                                 'svm_poly',
                                 'svm_rbf'])