Skip to content
Snippets Groups Projects
Commit 125817c1 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add back kmeans method in the trainer (disapeared during a merge)

parent 1db36b5d
No related branches found
No related tags found
2 merge requests!20Resolve "integration-sota",!19WIP: Resolve "Adding new datasets"
...@@ -96,7 +96,7 @@ class Trainer(object): ...@@ -96,7 +96,7 @@ class Trainer(object):
self._end_time = time.time() self._end_time = time.time()
def __score_func(self, model, X, y_true, weights=True): def __score_func(self, model, X, y_true, weights=True):
if type(model) in [OmpForestRegressor, RandomForestRegressor, SimilarityForestRegressor]: if type(model) in [OmpForestRegressor, RandomForestRegressor]:
if weights: if weights:
y_pred = model.predict(X) y_pred = model.predict(X)
else: else:
...@@ -109,12 +109,14 @@ class Trainer(object): ...@@ -109,12 +109,14 @@ class Trainer(object):
y_pred = model.predict_no_weights(X) y_pred = model.predict_no_weights(X)
if type(model) is OmpForestBinaryClassifier: if type(model) is OmpForestBinaryClassifier:
y_pred = np.sign(y_pred) y_pred = np.sign(y_pred)
y_pred = np.where(y_pred==0, 1, y_pred) y_pred = np.where(y_pred == 0, 1, y_pred)
result = self._classification_score_metric(y_true, y_pred) result = self._classification_score_metric(y_true, y_pred)
elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor]:
result = model.score(X, y_true)
return result return result
def __score_func_base(self, model, X, y_true): def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor: if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor]:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = self._base_regression_score_metric(y_true, y_pred) result = self._base_regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
...@@ -123,7 +125,7 @@ class Trainer(object): ...@@ -123,7 +125,7 @@ class Trainer(object):
elif type(model) == RandomForestClassifier: elif type(model) == RandomForestClassifier:
y_pred = model.predict(X) y_pred = model.predict(X)
result = self._base_classification_score_metric(y_true, y_pred) result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) in [RandomForestRegressor, SimilarityForestRegressor]: elif type(model) is RandomForestRegressor:
y_pred = model.predict(X) y_pred = model.predict(X)
result = self._base_regression_score_metric(y_true, y_pred) result = self._base_regression_score_metric(y_true, y_pred)
return result return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment