Skip to content
Snippets Groups Projects
Commit 3a2ec5cb authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix score func in trainer

parent 6239030f
Branches
No related tags found
1 merge request!9Resolve "Experiment pipeline"
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
......@@ -51,7 +51,8 @@ class Trainer(object):
def train(self, model):
"""
:param model: Object with
:param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return:
"""
......@@ -72,32 +73,28 @@ class Trainer(object):
self._end_time = time.time()
def __score_func(self, model, X, y_true):
if type(model) == OmpForestRegressor:
if type(model) in [OmpForestRegressor, RandomForestRegressor]:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred)
else:
y_pred = model.predict(X)
result = model.score(y_true, y_pred)
return result
def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor:
y_pred = model.predict_base_estimator(X)
result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict_base_estimator(X)
result = accuracy_score(y_true, y_pred)
else:
y_pred = model.predict_base_estimator(X)
result = model.score(y_true, y_pred)
elif type(model) == RandomForestClassifier:
y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred)
elif type(model) == RandomForestRegressor:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment