Skip to content
Snippets Groups Projects
Commit 143b37a8 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Added init_hypotheses

parent 5dce3320
Branches
Tags
No related merge requests found
...@@ -51,23 +51,15 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -51,23 +51,15 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
if scipy.sparse.issparse(X): if scipy.sparse.issparse(X):
logging.info('Converting to dense matrix.') logging.info('Converting to dense matrix.')
X = np.array(X.todense()) X = np.array(X.todense())
if self.estimators_generator is None:
self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps,
self_complemented=self.self_complemented)
# Initialization # Initialization
y[y == 0] = -1 y[y == 0] = -1
y = y.reshape((y.shape[0], 1))
self.init_info_containers()
self.estimators_generator.fit(X, y)
self.classification_matrix = self._binary_classification_matrix(X)
self.init_info_containers()
m, n = self.classification_matrix.shape
y = y.reshape((m,1))
y_kernel_matrix = np.multiply(y, self.classification_matrix)
m,n,y_kernel_matrix = self.init_hypotheses(X, y)
self.example_weights = self._initialize_alphas(m).reshape((m,1)) self.example_weights = self._initialize_alphas(m).reshape((m,1))
...@@ -173,6 +165,17 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -173,6 +165,17 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
self.predict_time = end - start self.predict_time = end - start
return signs_array return signs_array
def init_hypotheses(self, X, y):
if self.estimators_generator is None:
self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps,
self_complemented=self.self_complemented)
self.estimators_generator.fit(X, y)
self.classification_matrix = self._binary_classification_matrix(X)
m, n = self.classification_matrix.shape
y_kernel_matrix = np.multiply(y, self.classification_matrix)
return m,n,y_kernel_matrix
def init_info_containers(self): def init_info_containers(self):
self.weights_ = [] self.weights_ = []
self.chosen_columns_ = [] self.chosen_columns_ = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment