Skip to content
Snippets Groups Projects
Commit d3f1d454 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Wrote weightedlinear early fusion, tests are passing.

parent 3786f3ed
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,8 @@ from ..monoview.monoview_utils import CustomRandint, BaseMonoviewClassifier
__author__ = "Baptiste Bauvin"
__status__ = "Prototype" # Production, Development, Prototype
classifier_class_name = "DecisionTree"
class DecisionTree(DecisionTreeClassifier, BaseMonoviewClassifier):
......
import logging
import math
import time
from collections import defaultdict
import numpy as np
import numpy.ma as ma
import scipy
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted
from ... import metrics
class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, random_state):
self.random_state = random_state
def genBestParams(self, detector):
return dict((param_name, detector.best_params_[param_name])
for param_name in self.param_names)
def genParamsFromDetector(self, detector):
if self.classed_params:
classed_dict = dict((classed_param, get_names(
detector.cv_results_["param_" + classed_param]))
for classed_param in self.classed_params)
if self.param_names:
return [(param_name,
np.array(detector.cv_results_["param_" + param_name]))
if param_name not in self.classed_params else (
param_name, classed_dict[param_name])
for param_name in self.param_names]
else:
return [()]
def genDistribs(self):
return dict((param_name, distrib) for param_name, distrib in
zip(self.param_names, self.distribs))
def getConfig(self):
if self.param_names:
return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join(
[param_name + " : " + self.to_str(param_name) for param_name in
self.param_names])
else:
return "\n\t\t- " + self.__class__.__name__ + "with no config."
def to_str(self, param_name):
if param_name in self.weird_strings:
if self.weird_strings[param_name] == "class_name":
return self.get_params()[param_name].__class__.__name__
else:
return self.weird_strings[param_name](
self.get_params()[param_name])
else:
return str(self.get_params()[param_name])
def get_interpretation(self):
return "No detailed interpretation function"
def get_train_views_indices(dataset, train_indices, view_indices):
"""This function is used to get all the examples indices and view indices if needed"""
if view_indices is None:
view_indices = np.arange(dataset["Metadata"].attrs["nbView"])
if train_indices is None:
train_indices = range(dataset["Metadata"].attrs["datasetLength"])
return train_indices, view_indices
import numpy as np
from ..utils.dataset import getV
from .additions.utils import BaseMultiviewClassifier, get_train_views_indices
from .. import monoview_classifiers
class WeightedLinearEarlyFusion(BaseMultiviewClassifier):
def __init__(self, view_weights=None, monoview_classifier="decision_tree", monoview_classifier_config=None, random_state=42):
super(WeightedLinearEarlyFusion, self).__init__(random_state=random_state)
self.view_weights = np.array(view_weights)
if type(monoview_classifier) == str:
monoview_classifier_module = getattr(monoview_classifiers,
monoview_classifier)
monoview_classifier_class = getattr(monoview_classifier_module, monoview_classifier_module.classifier_class_name)
self.monoview_classifier = monoview_classifier_class(**monoview_classifier_config)
else:
self.monoview_classifier = monoview_classifier
def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, X = self.transform_data_to_monoview(X, train_indices, view_indices)
self.monoview_classifier.fit(X, y[train_indices])
def predict(self, X, predict_indices, view_indices):
_, X = self.transform_data_to_monoview(X, predict_indices, view_indices)
predicted_labels = self.monoview_classifier.predict(X)
return predicted_labels
def transform_data_to_monoview(self, dataset, example_indices, view_indices):
"""Here, we extract the data from the HDF5 dataset file and store all
the concatenated views in one variable"""
example_indices, self.view_indices = get_train_views_indices(dataset,
example_indices,
view_indices)
if self.view_weights is None:
self.view_weights = np.ones(len(self.view_indices), dtype=float)
self.view_weights /= float(np.sum(self.view_weights))
X = self.hdf5_to_monoview(dataset, example_indices, self.view_indices)
return example_indices, X
def hdf5_to_monoview(self, dataset, exmaples, view_indices):
"""Here, we concatenate the views for the asked examples """
monoview_data = np.concatenate(
[getV(dataset, view_idx, exmaples)
for view_weight, (index, view_idx)
in zip(self.view_weights, enumerate(view_indices))]
, axis=1)
return monoview_data
import unittest
import numpy as np
import h5py
import os
from multiview_platform.mono_multi_view_classifiers.multiview_classifiers import \
weighted_linear_early_fusion
class Test_WeightedLinearEarlyFusion(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.random_state = np.random.RandomState(42)
cls.view_weights = [0.5, 0.5]
os.mkdir("multiview_platform/tests/tmp_tests")
cls.dataset_file = h5py.File(
"multiview_platform/tests/tmp_tests/test_file.hdf5", "w")
cls.labels = cls.dataset_file.create_dataset("Labels",
data=np.array([0, 1, 0, 0, 1]))
cls.view0_data = cls.random_state.randint(1,10,size=(5, 4))
view0 = cls.dataset_file.create_dataset("View0", data=cls.view0_data)
view0.attrs["sparse"] = False
cls.view1_data = cls.random_state.randint(1, 10, size=(5, 4))
view1 = cls.dataset_file.create_dataset("View1", data=cls.view1_data)
view1.attrs["sparse"] = False
metaDataGrp = cls.dataset_file.create_group("Metadata")
metaDataGrp.attrs["nbView"] = 2
metaDataGrp.attrs["nbClass"] = 2
metaDataGrp.attrs["datasetLength"] = 5
cls.monoview_classifier_name = "decision_tree"
cls.monoview_classifier_config = {"max_depth":1, "criterion": "gini", "splitter": "best"}
cls.classifier = weighted_linear_early_fusion.WeightedLinearEarlyFusion(
cls.view_weights,
monoview_classifier=cls.monoview_classifier_name,
monoview_classifier_config=cls.monoview_classifier_config)
@classmethod
def tearDownClass(cls):
cls.dataset_file.close()
for file_name in os.listdir("multiview_platform/tests/tmp_tests"):
os.remove(os.path.join("multiview_platform/tests/tmp_tests", file_name))
os.rmdir("multiview_platform/tests/tmp_tests")
def test_simple(self):
np.testing.assert_array_equal(self.view_weights, self.classifier.view_weights)
def test_fit(self):
self.assertRaises(AttributeError, getattr,
self.classifier.monoview_classifier, "classes_")
self.classifier.fit(self.dataset_file, self.labels, None, None)
np.testing.assert_array_equal(self.classifier.monoview_classifier.classes_,
np.array([0,1]))
def test_predict(self):
self.classifier.fit(self.dataset_file, self.labels, None, None)
predicted_labels = self.classifier.predict(self.dataset_file, None, None)
np.testing.assert_array_equal(predicted_labels, self.labels)
def test_transform_data_to_monoview_simple(self):
example_indices, X = self.classifier.transform_data_to_monoview(self.dataset_file,
None, None)
self.assertEqual(X.shape, (5,8))
np.testing.assert_array_equal(X, np.concatenate((self.view0_data, self.view1_data), axis=1))
np.testing.assert_array_equal(example_indices, np.arange(5))
def test_transform_data_to_monoview_view_select(self):
example_indices, X = self.classifier.transform_data_to_monoview(
self.dataset_file,
None, np.array([0]))
self.assertEqual(X.shape, (5, 4))
np.testing.assert_array_equal(X, self.view0_data)
np.testing.assert_array_equal(example_indices, np.arange(5))
def test_transform_data_to_monoview_view_select(self):
example_indices, X = self.classifier.transform_data_to_monoview(
self.dataset_file,
np.array([1,2,3]), np.array([0]))
self.assertEqual(X.shape, (3, 4))
np.testing.assert_array_equal(X, self.view0_data[np.array([1,2,3]), :])
np.testing.assert_array_equal(example_indices, np.array([1,2,3]))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment