Skip to content
Snippets Groups Projects
Commit 3a2ec5cb authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Fix score func in trainer

parent 6239030f
Branches
No related tags found
1 merge request!9Resolve "Experiment pipeline"
...@@ -51,7 +51,8 @@ class Trainer(object): ...@@ -51,7 +51,8 @@ class Trainer(object):
def train(self, model): def train(self, model):
""" """
:param model: Object with :param model: An instance of either RandomForestRegressor, RandomForestClassifier, OmpForestRegressor,
OmpForestBinaryClassifier, OmpForestMulticlassClassifier.
:return: :return:
""" """
...@@ -72,32 +73,28 @@ class Trainer(object): ...@@ -72,32 +73,28 @@ class Trainer(object):
self._end_time = time.time() self._end_time = time.time()
def __score_func(self, model, X, y_true): def __score_func(self, model, X, y_true):
if type(model) == OmpForestRegressor: if type(model) in [OmpForestRegressor, RandomForestRegressor]:
y_pred = model.predict(X) y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred) result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier, RandomForestClassifier]:
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict(X) y_pred = model.predict(X)
result = accuracy_score(y_true, y_pred) result = accuracy_score(y_true, y_pred)
else:
y_pred = model.predict(X)
result = model.score(y_true, y_pred)
return result return result
def __score_func_base(self, model, X, y_true): def __score_func_base(self, model, X, y_true):
if type(model) == OmpForestRegressor: if type(model) == OmpForestRegressor:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = mean_squared_error(y_true, y_pred) result = mean_squared_error(y_true, y_pred)
elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]: elif type(model) in [OmpForestBinaryClassifier, OmpForestMulticlassClassifier]:
y_pred = model.predict_base_estimator(X) y_pred = model.predict_base_estimator(X)
result = accuracy_score(y_true, y_pred) result = accuracy_score(y_true, y_pred)
elif type(model) == RandomForestClassifier:
else: y_pred = model.predict(X)
y_pred = model.predict_base_estimator(X) result = accuracy_score(y_true, y_pred)
result = model.score(y_true, y_pred) elif type(model) == RandomForestRegressor:
y_pred = model.predict(X)
result = mean_squared_error(y_true, y_pred)
return result return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment