Skip to content
Snippets Groups Projects
Commit 3a2ec5cb authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix score func in trainer

parent 6239030f
No related branches found
No related tags found
1 merge request!9Resolve "Experiment pipeline"
...@@ -51,7 +51,8 @@ class Trainer(object): ...@@ -51,7 +51,8 @@ class Trainer(object):
def train(self, model): def train(self, model):
""" """
:param model: Object with :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return: :return:
""" """
...@@ -72,32 +73,28 @@ class Trainer(object): ...@@ -72,32 +73,28 @@ class Trainer(object):
self._end_time = time.time() self._end_time = time.time()
def __score_func(self, model, X, y_true): def __score_func(self, model, X, y_true):
if type(model) == OmpForestRegressor: if type(model) in [OmpForestRegressor, RandomForestRegressor]:
y_pred = model.predict(X) y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred) result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict(X) y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred) result = accuracy_score(y_true, y_pred)
else:
y_pred = model.predict(X)
result = model.score(y_true, y_pred)
return result return result
def __score_func_base(self, model, X, y_true): def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor: if type(model) == OmpForestRegressor:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = mean_squared_error(y_true, y_pred) result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = accuracy_score(y_true, y_pred) result = accuracy_score(y_true, y_pred)
elif type(model) == RandomForestClassifier:
else: y_pred = model.predict(X)
y_pred = model.predict_base_estimator(X) result = accuracy_score(y_true, y_pred)
result = model.score(y_true, y_pred) elif type(model) == RandomForestRegressor:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
return result return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment