model_factory.py 3.42 KB
Newer Older
1
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
2
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
3
from bolsonaro.models.model_parameters import ModelParameters
4
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
5
from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor
6
from bolsonaro.models.ensemble_selection_forest_regressor import EnsembleSelectionForestRegressor
7
from bolsonaro.data.task import Task
8

9
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
10
11
import os
import pickle
12
13
14
15
16


class ModelFactory(object):
    """Factory selecting the forest model for a task and an extraction strategy."""

    @staticmethod
    def build(task, model_parameters, library=None):
        """Instantiate the model for ``task`` as chosen by
        ``model_parameters.extraction_strategy``.

        Strategies available per task:
          * every task: 'omp', 'random', 'none'
          * Task.REGRESSION only: 'similarity', 'kmeans', 'ensemble'

        Args:
            task: one of Task.BINARYCLASSIFICATION, Task.REGRESSION or
                Task.MULTICLASSIFICATION.
            model_parameters: parameters object; this method reads
                ``extraction_strategy``, ``extracted_forest_size``, ``seed``
                and ``hyperparameters['n_estimators']`` (the latter only for
                the 'none' strategy).
            library: forwarded to EnsembleSelectionForestRegressor for the
                regression 'ensemble' strategy; ignored otherwise.

        Returns:
            A newly constructed model instance.

        Raises:
            ValueError: if ``task`` is not supported, or if the extraction
                strategy is not available for ``task``.
        """
        if task not in [Task.BINARYCLASSIFICATION, Task.REGRESSION, Task.MULTICLASSIFICATION]:
            raise ValueError("Unsupported task '{}'".format(task))

        strategy = model_parameters.extraction_strategy
        builders = ModelFactory._strategy_builders(task, model_parameters, library)
        builder = builders.get(strategy)
        if builder is None:
            # Name the offending strategy/task so misconfigurations are easy to spot.
            raise ValueError("Invalid extraction strategy '{}' for task '{}'".format(strategy, task))
        return builder()

    @staticmethod
    def _strategy_builders(task, model_parameters, library):
        """Return a dict mapping strategy name -> zero-argument constructor for ``task``.

        Constructors are lambdas so that no model is built (and no
        ``model_parameters`` attribute specific to one strategy is read)
        until the selected strategy is actually invoked.
        """
        if task == Task.REGRESSION:
            forest_cls = RandomForestRegressor
            builders = {
                'omp': lambda: OmpForestRegressor(model_parameters),
                'similarity': lambda: SimilarityForestRegressor(model_parameters),
                'kmeans': lambda: KMeansForestRegressor(model_parameters),
                'ensemble': lambda: EnsembleSelectionForestRegressor(model_parameters, library=library),
            }
        elif task == Task.BINARYCLASSIFICATION:
            forest_cls = RandomForestClassifier
            builders = {'omp': lambda: OmpForestBinaryClassifier(model_parameters)}
        else:
            # Task.MULTICLASSIFICATION: build() has already rejected any other task.
            forest_cls = RandomForestClassifier
            builders = {'omp': lambda: OmpForestMulticlassClassifier(model_parameters)}

        # Baselines shared by every task, differing only by the forest class:
        # 'random' builds a forest with as many trees as the extracted forest,
        # 'none' builds the full, unpruned forest from the hyperparameters.
        builders['random'] = lambda: forest_cls(
            n_estimators=model_parameters.extracted_forest_size,
            random_state=model_parameters.seed)
        builders['none'] = lambda: forest_cls(
            n_estimators=model_parameters.hyperparameters['n_estimators'],
            random_state=model_parameters.seed)
        return builders