Skip to content
Snippets Groups Projects
Commit e207ec6f authored by Luc Giffon's avatar Luc Giffon
Browse files

fix bug predict

parent e045a789
No related branches found
No related tags found
1 merge request!24Resolve "non negative omp"
......@@ -167,32 +167,6 @@ class Trainer(object):
:param model: Object with
:param models_dir: Where the results will be saved
"""
model_weights = ''
if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
model_weights = model._omp.coef_
elif type(model) == OmpForestMulticlassClassifier:
model_weights = model._dct_class_omp
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
selected_trees = model.selected_trees
elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
selected_trees = model.estimators_
if len(selected_trees) > 0:
target_selected_tree = int(os.path.split(models_dir)[-1])
if target_selected_tree != len(selected_trees):
raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
pickle.dump(selected_trees, output_file)
strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric
# Reeeally dirty to put that here but otherwise it's not thread safe...
if type(model) in [RandomForestRegressor, RandomForestClassifier]:
if subsets_used == 'train,dev':
......@@ -221,6 +195,38 @@ class Trainer(object):
else:
raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used))
model_weights = ''
if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]:
model_weights = model._omp.coef_
elif type(model) == OmpForestMulticlassClassifier:
model_weights = model._dct_class_omp
elif type(model) == OmpForestBinaryClassifier:
model_weights = model._omp
if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor,
SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]:
selected_trees = model.selected_trees
elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]:
selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0]
elif type(model) in [RandomForestRegressor, RandomForestClassifier]:
selected_trees = model.estimators_
if len(selected_trees) > 0:
target_selected_tree = int(os.path.split(models_dir)[-1])
if target_selected_tree != len(selected_trees):
predictions_X_omp = model.predict(X_omp)
error_prediction = np.linalg.norm(predictions_X_omp - y_omp)
if not np.isclose(error_prediction, 0):
raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}')
else:
self._logger.warning(f"Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}"
" But the prediction is perfect on X_omp. Keep less trees.")
with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file:
pickle.dump(selected_trees, output_file)
strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric
results = ModelRawResults(
model_weights=model_weights,
training_time=self._end_time - self._begin_time,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment