Skip to content
Snippets Groups Projects
Select Git revision
  • fcab676c7fa5928038ed57c8780e1e4eb2768b29
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

CQBoost.py

Blame
  • model_factory.py 3.93 KiB
    from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
    from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
    from bolsonaro.models.model_parameters import ModelParameters
    from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
    from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
    from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor, EnsembleSelectionForestClassifier
    from bolsonaro.data.task import Task
    
    from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
    import os
    import pickle
    
    
    class ModelFactory(object):
        """Factory mapping a learning task and an extraction strategy to a model instance."""

        # Strategies that build a plain scikit-learn random forest (no tree
        # extraction); 'random' and 'none' are handled identically for every task.
        _PLAIN_FOREST_STRATEGIES = ('random', 'none')

        @staticmethod
        def build(task, model_parameters):
            """Construct the model for ``task`` configured by ``model_parameters``.

            Args:
                task: one of ``Task.BINARYCLASSIFICATION``, ``Task.REGRESSION`` or
                    ``Task.MULTICLASSIFICATION``.
                model_parameters: presumably a ``ModelParameters`` instance; this
                    method reads its ``extraction_strategy``, ``hyperparameters``
                    (unpacked as keyword arguments) and ``seed`` attributes.

            Returns:
                For the 'random' and 'none' strategies, a newly constructed
                ``RandomForestClassifier`` / ``RandomForestRegressor`` built from
                ``hyperparameters`` with ``random_state=seed``. For every other
                supported strategy, an instance of the matching project model
                class constructed with ``model_parameters``.

            Raises:
                ValueError: if ``task`` is unsupported, or if the extraction strategy
                    is not available for ``task``.
            """
            if task not in [Task.BINARYCLASSIFICATION, Task.REGRESSION, Task.MULTICLASSIFICATION]:
                raise ValueError("Unsupported task '{}'".format(task))

            strategy = model_parameters.extraction_strategy

            # Per task: the scikit-learn forest used by the plain strategies, and the
            # project model class used by each remaining strategy.
            if task == Task.BINARYCLASSIFICATION:
                forest_class = RandomForestClassifier
                model_classes = {
                    'omp': OmpForestBinaryClassifier,
                    'omp_nn': OmpForestBinaryClassifier,
                    'omp_distillation': OmpForestBinaryClassifier,
                    'ensemble': EnsembleSelectionForestClassifier,
                    'kmeans': KMeansForestClassifier,
                    'similarity_similarities': SimilarityForestClassifier,
                    'similarity_predictions': SimilarityForestClassifier,
                }
            elif task == Task.REGRESSION:
                forest_class = RandomForestRegressor
                model_classes = {
                    'omp': OmpForestRegressor,
                    'omp_nn': OmpForestRegressor,
                    'omp_distillation': OmpForestRegressor,
                    'ensemble': EnsembleSelectionForestRegressor,
                    'kmeans': KMeansForestRegressor,
                    'similarity_similarities': SimilarityForestRegressor,
                    'similarity_predictions': SimilarityForestRegressor,
                }
            else:
                # Task.MULTICLASSIFICATION (the only task left after validation):
                # only the OMP family and plain forests are wired up here.
                # NOTE(review): confirm whether the ensemble / kmeans / similarity
                # classifiers handle more than two classes before adding them.
                forest_class = RandomForestClassifier
                model_classes = {
                    'omp': OmpForestMulticlassClassifier,
                    'omp_nn': OmpForestMulticlassClassifier,
                    'omp_distillation': OmpForestMulticlassClassifier,
                }

            if strategy in ModelFactory._PLAIN_FOREST_STRATEGIES:
                return forest_class(**model_parameters.hyperparameters,
                    random_state=model_parameters.seed)

            model_class = model_classes.get(strategy)
            if model_class is None:
                # Keep the original message prefix but report the offending values.
                raise ValueError("Invalid extraction strategy '{}' for task '{}'".format(strategy, task))
            return model_class(model_parameters)