Skip to content
Snippets Groups Projects
Commit 58c55652 authored by Baptiste Bauvin's avatar Baptiste Bauvin
Browse files

Trees were different for each algo

parent 9b7f9916
No related branches found
No related tags found
No related merge requests found
...@@ -191,13 +191,16 @@ class ClassifiersGenerator(BaseEstimator, TransformerMixin): ...@@ -191,13 +191,16 @@ class ClassifiersGenerator(BaseEstimator, TransformerMixin):
class TreeClassifiersGenerator(ClassifiersGenerator): class TreeClassifiersGenerator(ClassifiersGenerator):
def __init__(self, random_state, max_depth=2, self_complemented=True, criterion="gini", splitter="best", n_trees=100, distribution_type="uniform", low=0, high=10, attributes_ratio=0.6, examples_ratio=0.95): def __init__(self, random_state=42, max_depth=2, self_complemented=True, criterion="gini", splitter="best", n_trees=100, distribution_type="uniform", low=0, high=10, attributes_ratio=0.6, examples_ratio=0.95):
super(TreeClassifiersGenerator, self).__init__(self_complemented) super(TreeClassifiersGenerator, self).__init__(self_complemented)
self.max_depth=max_depth self.max_depth=max_depth
self.criterion=criterion self.criterion=criterion
self.splitter=splitter self.splitter=splitter
self.n_trees=n_trees self.n_trees=n_trees
self.random_state=random_state if type(random_state) is int:
self.random_state = np.random.RandomState(random_state)
else:
self.random_state=random_state
self.distribution_type = distribution_type self.distribution_type = distribution_type
self.low = low self.low = low
self.high = high self.high = high
...@@ -208,6 +211,7 @@ class TreeClassifiersGenerator(ClassifiersGenerator): ...@@ -208,6 +211,7 @@ class TreeClassifiersGenerator(ClassifiersGenerator):
estimators_ = [] estimators_ = []
self.attribute_indices = [self.sub_sample_attributes(X) for _ in range(self.n_trees)] self.attribute_indices = [self.sub_sample_attributes(X) for _ in range(self.n_trees)]
self.example_indices = [self.sub_sample_examples(X) for _ in range(self.n_trees)] self.example_indices = [self.sub_sample_examples(X) for _ in range(self.n_trees)]
print(self.example_indices)
for i in range(self.n_trees): for i in range(self.n_trees):
estimators_.append(DecisionTreeClassifier(criterion=self.criterion, splitter=self.splitter, max_depth=self.max_depth).fit(X[:, self.attribute_indices[i]][self.example_indices[i], :], y[self.example_indices[i]])) estimators_.append(DecisionTreeClassifier(criterion=self.criterion, splitter=self.splitter, max_depth=self.max_depth).fit(X[:, self.attribute_indices[i]][self.example_indices[i], :], y[self.example_indices[i]]))
self.estimators_ = np.asarray(estimators_) self.estimators_ = np.asarray(estimators_)
......
...@@ -333,7 +333,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -333,7 +333,7 @@ class ColumnGenerationClassifierQar(BaseEstimator, ClassifierMixin, BaseBoost):
n_stumps_per_attribute=self.n_stumps, n_stumps_per_attribute=self.n_stumps,
self_complemented=self.self_complemented) self_complemented=self.self_complemented)
if self.estimators_generator is "Trees": if self.estimators_generator is "Trees":
self.estimators_generator = TreeClassifiersGenerator(self.random_state, n_trees=self.n_stumps, max_depth=self.max_depth, self.estimators_generator = TreeClassifiersGenerator(n_trees=self.n_stumps, max_depth=self.max_depth,
self_complemented=self.self_complemented) self_complemented=self.self_complemented)
self.estimators_generator.fit(X, y) self.estimators_generator.fit(X, y)
self.classification_matrix = self._binary_classification_matrix(X) self.classification_matrix = self._binary_classification_matrix(X)
......
...@@ -26,7 +26,6 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -26,7 +26,6 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
self.random_state = random_state self.random_state = random_state
def fit(self, X, y): def fit(self, X, y):
start = time.time()
if scipy.sparse.issparse(X): if scipy.sparse.issparse(X):
X = np.array(X.todense()) X = np.array(X.todense())
...@@ -35,8 +34,9 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -35,8 +34,9 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
if self.estimators_generator is "Stumps": if self.estimators_generator is "Stumps":
self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps, self_complemented=True) self.estimators_generator = StumpsClassifiersGenerator(n_stumps_per_attribute=self.n_stumps, self_complemented=True)
elif self.estimators_generator is "Trees": elif self.estimators_generator is "Trees":
self.estimators_generator = TreeClassifiersGenerator( self.random_state, max_depth=self.max_depth, n_trees=self.n_stumps, self_complemented=True) self.estimators_generator = TreeClassifiersGenerator(max_depth=self.max_depth, n_trees=self.n_stumps, self_complemented=True)
print(self.max_depth, self.n_stumps)
self.estimators_generator.fit(X, y) self.estimators_generator.fit(X, y)
self.classification_matrix = self._binary_classification_matrix(X) self.classification_matrix = self._binary_classification_matrix(X)
self.c_bounds = [] self.c_bounds = []
...@@ -63,6 +63,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost): ...@@ -63,6 +63,7 @@ class ColumnGenerationClassifier(BaseEstimator, ClassifierMixin, BaseBoost):
w= None w= None
self.collected_weight_vectors_ = {} self.collected_weight_vectors_ = {}
self.collected_dual_constraint_violations_ = {} self.collected_dual_constraint_violations_ = {}
start = time.time()
for k in range(min(n, self.n_max_iterations if self.n_max_iterations is not None else np.inf)): for k in range(min(n, self.n_max_iterations if self.n_max_iterations is not None else np.inf)):
# Find worst weak hypothesis given alpha. # Find worst weak hypothesis given alpha.
......
...@@ -13,7 +13,7 @@ class PregenClassifier(BaseBoost): ...@@ -13,7 +13,7 @@ class PregenClassifier(BaseBoost):
n_stumps_per_attribute=self.n_stumps, n_stumps_per_attribute=self.n_stumps,
self_complemented=self.self_complemented) self_complemented=self.self_complemented)
elif generator is "Trees": elif generator is "Trees":
self.estimators_generator = TreeClassifiersGenerator(self.random_state, n_trees=self.n_stumps, max_depth=self.max_depth) self.estimators_generator = TreeClassifiersGenerator(n_trees=self.n_stumps, max_depth=self.max_depth)
self.estimators_generator.fit(X, neg_y) self.estimators_generator.fit(X, neg_y)
else: else:
neg_y=None neg_y=None
......
...@@ -10,8 +10,7 @@ class MinCQGraalpyTree(RegularizedBinaryMinCqClassifier, BaseMonoviewClassifier) ...@@ -10,8 +10,7 @@ class MinCQGraalpyTree(RegularizedBinaryMinCqClassifier, BaseMonoviewClassifier)
def __init__(self, random_state=None, mu=0.01, self_complemented=True, n_stumps_per_attribute=1, max_depth=2, **kwargs): def __init__(self, random_state=None, mu=0.01, self_complemented=True, n_stumps_per_attribute=1, max_depth=2, **kwargs):
super(MinCQGraalpyTree, self).__init__(mu=mu, super(MinCQGraalpyTree, self).__init__(mu=mu,
estimators_generator=TreeClassifiersGenerator(random_state=random_state, estimators_generator=TreeClassifiersGenerator(n_trees=n_stumps_per_attribute,
n_trees=n_stumps_per_attribute,
max_depth=max_depth, max_depth=max_depth,
self_complemented=self_complemented), self_complemented=self_complemented),
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment