Commit f9611920 authored by Baptiste Bauvin's avatar Baptiste Bauvin

Modifications

parent 78dd68e7
@@ -47,6 +47,9 @@ from sklearn.tree import DecisionTreeClassifier
from sklearn.tree.tree import BaseDecisionTree
from sklearn.tree._tree import DTYPE
from sklearn.ensemble.forest import BaseForest
from sklearn.base import clone
from sklearn.ensemble._base import _set_random_states
from sklearn.ensemble import BaseEnsemble
from multimodal.datasets.data_sample import DataSample
from multimodal.datasets.data_sample import MultiModalData, MultiModalArray, MultiModalSparseArray
@@ -57,6 +60,27 @@ class UBoosting(metaclass=ABCMeta):
    Base class providing shared boosting helper methods for multimodal estimators.
"""
    def _make_unique_estimator(self, base_estimator, estimator_params,
                               append=True, random_state=None):
        # Copy/paste of sklearn.ensemble.BaseEnsemble._make_estimator
estimator = clone(base_estimator)
estimator.set_params(**{p: getattr(self, p)
for p in estimator_params})
if random_state is not None:
_set_random_states(estimator, random_state)
if append:
self.estimators_.append(estimator)
return estimator
    def _make_estimator(self, append=True, random_state=None, ind_view=0):
        # Per-view prototype when base_estimator_ is a list, shared one otherwise.
        if type(self.base_estimator_) is list:
            return self._make_unique_estimator(
                self.base_estimator_[ind_view], self.estimator_params[ind_view],
                append=append, random_state=random_state)
        else:
            return self._make_unique_estimator(
                self.base_estimator_, self.estimator_params,
                append=append, random_state=random_state)
def _validate_X_predict(self, X):
"""Ensure that X is in the proper format."""
if (self.base_estimator is None or
......
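The _make_unique_estimator helper added above mirrors sklearn's BaseEnsemble._make_estimator: clone the prototype, copy over the listed parameters, then seed it. A minimal stand-alone sketch of that clone-and-configure pattern (the prototype and seed below are illustrative, not part of this commit):

from sklearn.base import clone
from sklearn.ensemble._base import _set_random_states
from sklearn.tree import DecisionTreeClassifier

prototype = DecisionTreeClassifier(max_depth=1)   # illustrative prototype
estimator = clone(prototype)                      # fresh, unfitted copy
_set_random_states(estimator, random_state=42)    # seed any random_state params
# estimator is now ready to be fitted on one view and appended to estimators_.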
@@ -46,8 +46,9 @@ estimator for classification implemented in the ``MumboClassifier`` class.
"""
import numpy as np
from sklearn.base import ClassifierMixin
from sklearn.base import ClassifierMixin, clone
from sklearn.ensemble import BaseEnsemble
from sklearn.ensemble._base import _set_random_states
from sklearn.ensemble.forest import BaseForest
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
@@ -174,7 +175,13 @@ class MumboClassifier(BaseEnsemble, ClassifierMixin, UBoosting):
random_state=None,
best_view_mode="edge"):
if type(base_estimator) is list:
self.base_estimator = base_estimator
self.n_estimators = n_estimators
self.estimator_params = [tuple() for _ in base_estimator]
else:
super(MumboClassifier, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators)
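With the list branch above, MumboClassifier can be handed one base estimator per view, and estimator_params then holds one empty tuple per view so each clone keeps its own parameters. A hedged usage sketch (the estimator choices are illustrative, and each must support sample_weight in fit()):

from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from multimodal.boosting.mumbo import MumboClassifier

# One prototype per view; both accept sample_weight in fit().
per_view = [DecisionTreeClassifier(max_depth=1), GaussianNB()]
clf = MumboClassifier(base_estimator=per_view, n_estimators=50)
assert clf.estimator_params == [(), ()]
# Fitting (e.g. clf.fit(X, y)) then clones per_view[ind_view] for the matching
# view at every boosting iteration.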
@@ -185,10 +192,37 @@ class MumboClassifier(BaseEnsemble, ClassifierMixin, UBoosting):
"""Check the estimator and set the base_estimator_ attribute."""
super(MumboClassifier, self)._validate_estimator(
default=DecisionTreeClassifier(max_depth=1))
if type(self.base_estimator_) is list:
for estimator in self.base_estimator_:
if not has_fit_parameter(estimator, "sample_weight"):
raise ValueError("%s doesn't support sample_weight."
% estimator.__class__.__name__)
else:
if not has_fit_parameter(self.base_estimator_, "sample_weight"):
raise ValueError("%s doesn't support sample_weight."
% self.base_estimator_.__class__.__name__)
def _make_estimator(self, append=True, random_state=None, ind_view=0):
if type(self.base_estimator_) is list:
estimator = clone(self.base_estimator_[ind_view])
estimator.set_params(**{p: getattr(self, p)
for p in self.estimator_params[ind_view]})
# TODO : modify estimator_params to be able to set a list
if random_state is not None:
_set_random_states(estimator, random_state)
if append:
self.estimators_.append(estimator)
return estimator
else:
return super(MumboClassifier, self)._make_estimator(append=append,
random_state=random_state)
def _validate_best_view_mode(self, best_view_mode):
"""Ensure that best_view_mode has a proper value."""
@@ -408,21 +442,24 @@ class MumboClassifier(BaseEnsemble, ClassifierMixin, UBoosting):
dist = self._compute_dist(cost, y)
for ind_view in range(n_views):
estimator = self._make_estimator(append=False,
random_state=random_state, ind_view=ind_view)
estimator.fit(self.X_._extract_view(ind_view), y,
sample_weight=dist[ind_view, :])
estimators.append(estimator)
predicted_classes[ind_view, :] = estimator.predict(
self.X_._extract_view(ind_view))
edges = self._compute_edge_global(
cost_global, predicted_classes, y)
if self.best_view_mode == "edge":
best_view = np.argmax(edges)
else: # self.best_view_mode == "error"
n_errors = np.sum(predicted_classes != y, axis=1)
best_view = np.argmin(n_errors)
print("Best view:", best_view)
edge = edges[best_view]
if (edge == 1.):
@@ -436,8 +473,8 @@ class MumboClassifier(BaseEnsemble, ClassifierMixin, UBoosting):
break
self.estimator_errors_[current_iteration] = (
np.average(cost_global[np.arange(y.shape[0]), y])
* (-1. / (self.n_classes_-1)))
alpha = self._compute_alphas(edge)
self.estimator_weights_[current_iteration] = alpha
......
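For reference, the best-view selection in the fit loop above boils down to an argmax/argmin per iteration; a tiny stand-alone illustration with made-up numbers:

import numpy as np

edges = np.array([0.15, 0.42, 0.30])   # weighted edge of each view (made-up values)
n_errors = np.array([12, 4, 7])        # training errors of each view (made-up values)
best_view_mode = "edge"
if best_view_mode == "edge":
    best_view = int(np.argmax(edges))      # "edge" mode: largest edge wins -> view 1
else:
    best_view = int(np.argmin(n_errors))   # "error" mode: fewest mistakes wins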
@@ -58,7 +58,7 @@ from sklearn import datasets
from multimodal.boosting.mumbo import MumboClassifier
class TestMumboClassifier(unittest.TestCase):
@classmethod
def setUpClass(clf):
......