diff --git a/code/bolsonaro/trainer.py b/code/bolsonaro/trainer.py index 78f2c082e4a9c20dfe7b6b5dfa2d5d49aca99cc2..b19569cdeae52d5e1d36c869a49aa2c56a93b150 100644 --- a/code/bolsonaro/trainer.py +++ b/code/bolsonaro/trainer.py @@ -167,32 +167,6 @@ class Trainer(object): :param model: Object with :param models_dir: Where the results will be saved """ - - model_weights = '' - if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]: - model_weights = model._omp.coef_ - elif type(model) == OmpForestMulticlassClassifier: - model_weights = model._dct_class_omp - elif type(model) == OmpForestBinaryClassifier: - model_weights = model._omp - - if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, - SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: - selected_trees = model.selected_trees - elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]: - selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0] - elif type(model) in [RandomForestRegressor, RandomForestClassifier]: - selected_trees = model.estimators_ - - if len(selected_trees) > 0: - target_selected_tree = int(os.path.split(models_dir)[-1]) - if target_selected_tree != len(selected_trees): - raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}') - with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file: - pickle.dump(selected_trees, output_file) - - strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric - # Reeeally dirty to put that here but otherwise it's not thread safe... if type(model) in [RandomForestRegressor, RandomForestClassifier]: if subsets_used == 'train,dev': @@ -221,6 +195,38 @@ class Trainer(object): else: raise ValueError("Unknown specified subsets_used parameter '{}'".format(model.models_parameters.subsets_used)) + model_weights = '' + if type(model) in [OmpForestRegressor, OmpForestBinaryClassifier]: + model_weights = model._omp.coef_ + elif type(model) == OmpForestMulticlassClassifier: + model_weights = model._dct_class_omp + elif type(model) == OmpForestBinaryClassifier: + model_weights = model._omp + + if type(model) in [SimilarityForestRegressor, KMeansForestRegressor, EnsembleSelectionForestRegressor, + SimilarityForestClassifier, KMeansForestClassifier, EnsembleSelectionForestClassifier]: + selected_trees = model.selected_trees + elif type(model) in [OmpForestRegressor, OmpForestMulticlassClassifier, OmpForestBinaryClassifier]: + selected_trees = np.asarray(model.forest)[model._omp.coef_ != 0] + elif type(model) in [RandomForestRegressor, RandomForestClassifier]: + selected_trees = model.estimators_ + + if len(selected_trees) > 0: + target_selected_tree = int(os.path.split(models_dir)[-1]) + if target_selected_tree != len(selected_trees): + predictions_X_omp = model.predict(X_omp) + error_prediction = np.linalg.norm(predictions_X_omp - y_omp) + if not np.isclose(error_prediction, 0): + raise ValueError(f'Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}') + else: + self._logger.warning(f"Invalid selected tree number target_selected_tree:{target_selected_tree} - len(selected_trees):{len(selected_trees)}" + " But the prediction is perfect on X_omp. Keep less trees.") + with open(os.path.join(models_dir, 'selected_trees.pickle'), 'wb') as output_file: + pickle.dump(selected_trees, output_file) + + strength_metric = self._regression_score_metric if self._dataset.task == Task.REGRESSION else self._classification_score_metric + + results = ModelRawResults( model_weights=model_weights, training_time=self._end_time - self._begin_time,