Skip to content
Snippets Groups Projects
Commit 125817c1 authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Add back kmeans method in the trainer (disapeared during a merge)

parent 1db36b5d
No related branches found
No related tags found
2 merge requests!20Resolve "integration-sota",!19WIP: Resolve "Adding new datasets"
......@@ -96,7 +96,7 @@ class Trainer(object):
self._end_time = time.time()
def __score_func(self, model, X, y_true, weights=True):
if type(model) in [OmpForestRegressor, RandomForestRegressor, SimilarityForestRegressor]:
if type(model) in [OmpForestRegressor, RandomForestRegressor]:
if weights:
y_pred = model.predict(X)
else:
......@@ -109,12 +109,14 @@ class Trainer(object):
y_pred = model.predict_no_weights(X)
if type(model) is OmpForestBinaryClassifier:
y_pred = np.sign(y_pred)
y_pred = np.where(y_pred==0, 1, y_pred)
y_pred = np.where(y_pred == 0, 1, y_pred)
result = self._classification_score_metric(y_true, y_pred)
elif type(model) in [SimilarityForestRegressor, KMeansForestRegressor]:
result = model.score(X, y_true)
return result
def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor:
if type(model) in [OmpForestRegressor, SimilarityForestRegressor, KMeansForestRegressor]:
y_pred = model.predict_base_estimator(X)
result = self._base_regression_score_metric(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
......@@ -123,7 +125,7 @@ class Trainer(object):
elif type(model) == RandomForestClassifier:
y_pred = model.predict(X)
result = self._base_classification_score_metric(y_true, y_pred)
elif type(model) in [RandomForestRegressor, SimilarityForestRegressor]:
elif type(model) is RandomForestRegressor:
y_pred = model.predict(X)
result = self._base_regression_score_metric(y_true, y_pred)
return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment