Skip to content
Snippets Groups Projects
Commit 65b27dd5 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added check view in develop

parent f9f5d88b
No related branches found
No related tags found
No related merge requests found
Showing
with 75 additions and 62 deletions
......@@ -4,7 +4,7 @@ import numpy as np
from .. import monoview_classifiers
from ..utils.base import BaseClassifier, ResultAnalyser
from ..utils.dataset import RAMDataset
from ..utils.dataset import RAMDataset, get_examples_views_indices
class FakeEstimator():
......@@ -29,6 +29,7 @@ class BaseMultiviewClassifier(BaseClassifier):
self.random_state = random_state
self.short_name = self.__module__.split(".")[-1]
self.weird_strings = {}
self.used_views = None
@abstractmethod
def fit(self, X, y, train_indices=None, view_indices=None):
......@@ -38,6 +39,10 @@ class BaseMultiviewClassifier(BaseClassifier):
def predict(self, X, example_indices=None, view_indices=None):
pass
def _check_views(self, view_indices):
if self.used_views is not None and not np.array_equal(np.sort(self.used_views), np.sort(view_indices)):
raise ValueError('Used {} views to fit, and trying to predict on {}'.format(self.used_views, view_indices))
def to_str(self, param_name):
if param_name in self.weird_strings:
string = ""
......
......@@ -30,6 +30,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier,
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
self.used_views = view_indices
# TODO : Finer analysis, may support a bit of mutliclass
if np.unique(y[train_indices]).shape[0] > 2:
raise ValueError(
......@@ -56,6 +57,7 @@ class DiversityFusionClassifier(BaseMultiviewClassifier,
example_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
self._check_views(view_indices)
nb_class = X.get_nb_class()
if nb_class > 2:
nb_class = 3
......
......@@ -27,6 +27,7 @@ class BaseJumboFusion(LateFusionClassifier):
example_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
self._check_views(view_indices)
monoview_decisions = self.predict_monoview(X,
example_indices=example_indices,
view_indices=view_indices)
......@@ -36,6 +37,7 @@ class BaseJumboFusion(LateFusionClassifier):
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
self.used_views = view_indices
self.init_classifiers(len(view_indices),
nb_monoview_per_view=self.nb_monoview_per_view)
self.fit_monoview_estimators(X, y, train_indices=train_indices,
......
......@@ -97,6 +97,7 @@ class LateFusionClassifier(BaseMultiviewClassifier, BaseFusionClassifier):
train_indices, view_indices = get_examples_views_indices(X,
train_indices,
view_indices)
self.used_views = view_indices
if np.unique(y).shape[0] > 2:
multiclass = True
else:
......
......@@ -6,59 +6,59 @@ def get_names(classed_list):
return np.array([object_.__class__.__name__ for object_ in classed_list])
class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin):
# class BaseMultiviewClassifier(BaseEstimator, ClassifierMixin):
#
# def __init__(self, random_state):
# self.random_state = random_state
#
# def genBestParams(self, detector):
# return dict((param_name, detector.best_params_[param_name])
# for param_name in self.param_names)
#
# def genParamsFromDetector(self, detector):
# if self.classed_params:
# classed_dict = dict((classed_param, get_names(
# detector.cv_results_["param_" + classed_param]))
# for classed_param in self.classed_params)
# if self.param_names:
# return [(param_name,
# np.array(detector.cv_results_["param_" + param_name]))
# if param_name not in self.classed_params else (
# param_name, classed_dict[param_name])
# for param_name in self.param_names]
# else:
# return [()]
#
# def genDistribs(self):
# return dict((param_name, distrib) for param_name, distrib in
# zip(self.param_names, self.distribs))
#
# def getConfig(self):
# if self.param_names:
# return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join(
# [param_name + " : " + self.to_str(param_name) for param_name in
# self.param_names])
# else:
# return "\n\t\t- " + self.__class__.__name__ + "with no config."
#
# def to_str(self, param_name):
# if param_name in self.weird_strings:
# if self.weird_strings[param_name] == "class_name":
# return self.get_params()[param_name].__class__.__name__
# else:
# return self.weird_strings[param_name](
# self.get_params()[param_name])
# else:
# return str(self.get_params()[param_name])
#
# def get_interpretation(self):
# return "No detailed interpretation function"
def __init__(self, random_state):
self.random_state = random_state
def genBestParams(self, detector):
return dict((param_name, detector.best_params_[param_name])
for param_name in self.param_names)
def genParamsFromDetector(self, detector):
if self.classed_params:
classed_dict = dict((classed_param, get_names(
detector.cv_results_["param_" + classed_param]))
for classed_param in self.classed_params)
if self.param_names:
return [(param_name,
np.array(detector.cv_results_["param_" + param_name]))
if param_name not in self.classed_params else (
param_name, classed_dict[param_name])
for param_name in self.param_names]
else:
return [()]
def genDistribs(self):
return dict((param_name, distrib) for param_name, distrib in
zip(self.param_names, self.distribs))
def getConfig(self):
if self.param_names:
return "\n\t\t- " + self.__class__.__name__ + "with " + ", ".join(
[param_name + " : " + self.to_str(param_name) for param_name in
self.param_names])
else:
return "\n\t\t- " + self.__class__.__name__ + "with no config."
def to_str(self, param_name):
if param_name in self.weird_strings:
if self.weird_strings[param_name] == "class_name":
return self.get_params()[param_name].__class__.__name__
else:
return self.weird_strings[param_name](
self.get_params()[param_name])
else:
return str(self.get_params()[param_name])
def get_interpretation(self):
return "No detailed interpretation function"
def get_train_views_indices(dataset, train_indices, view_indices, ):
"""This function is used to get all the examples indices and view indices if needed"""
if view_indices is None:
view_indices = np.arange(dataset.nb_view)
if train_indices is None:
train_indices = range(dataset.get_nb_examples())
return train_indices, view_indices
#
# def get_train_views_indices(dataset, train_indices, view_indices, ):
# """This function is used to get all the examples indices and view indices if needed"""
# if view_indices is None:
# view_indices = np.arange(dataset.nb_view)
# if train_indices is None:
# train_indices = range(dataset.get_nb_examples())
# return train_indices, view_indices
......@@ -23,7 +23,7 @@ class BayesianInferenceClassifier(LateFusionClassifier):
example_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
self._check_views(view_indices)
if sum(self.weights) != 1.0:
self.weights = self.weights / sum(self.weights)
......
......@@ -23,16 +23,16 @@ class MajorityVoting(LateFusionClassifier):
rs=rs)
def predict(self, X, example_indices=None, view_indices=None):
examples_indices, views_indices = get_examples_views_indices(X,
examples_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
self._check_views(view_indices)
n_examples = len(examples_indices)
votes = np.zeros((n_examples, X.get_nb_class(example_indices)),
dtype=float)
monoview_decisions = np.zeros((len(examples_indices), X.nb_view),
dtype=int)
for index, view_index in enumerate(views_indices):
for index, view_index in enumerate(view_indices):
monoview_decisions[:, index] = self.monoview_estimators[
index].predict(
X.get_v(view_index, examples_indices))
......
......@@ -69,6 +69,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
def fit(self, X, y, train_indices=None, view_indices=None):
train_indices, X = self.transform_data_to_monoview(X, train_indices,
view_indices)
self.used_views = view_indices
if np.unique(y[train_indices]).shape[0] > 2 and \
not (isinstance(self.monoview_classifier, MultiClassWrapper)):
self.monoview_classifier = get_mc_estim(self.monoview_classifier,
......@@ -81,6 +82,7 @@ class WeightedLinearEarlyFusion(BaseMultiviewClassifier, BaseFusionClassifier):
def predict(self, X, example_indices=None, view_indices=None):
_, X = self.transform_data_to_monoview(X, example_indices, view_indices)
self._check_views(self.view_indices)
predicted_labels = self.monoview_classifier.predict(X)
return predicted_labels
......
......@@ -17,11 +17,12 @@ class WeightedLinearLateFusion(LateFusionClassifier):
nb_cores=nb_cores, weights=weights, rs=rs)
def predict(self, X, example_indices=None, view_indices=None):
example_indices, views_indices = get_examples_views_indices(X,
example_indices, view_indices = get_examples_views_indices(X,
example_indices,
view_indices)
self._check_views(view_indices)
view_scores = []
for index, viewIndex in enumerate(views_indices):
for index, viewIndex in enumerate(view_indices):
view_scores.append(
np.array(self.monoview_estimators[index].predict_proba(
X.get_v(viewIndex, example_indices))) * self.weights[index])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment