Skip to content
Snippets Groups Projects

Resolve "non negative omp"

Merged Charly Lamothe requested to merge 24-non-negative-omp into master
20 files
+ 2261
201
Compare changes
  • Side-by-side
  • Inline
Files
20
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
 
from bolsonaro.models.nn_omp_forest_regressor import NonNegativeOmpForestRegressor
 
from bolsonaro.models.nn_omp_forest_classifier import NonNegativeOmpForestBinaryClassifier
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor, SimilarityForestClassifier
from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
from bolsonaro.models.kmeans_forest_regressor import KMeansForestRegressor, KMeansForestClassifier
@@ -19,8 +21,10 @@ class ModelFactory(object):
@@ -19,8 +21,10 @@ class ModelFactory(object):
raise ValueError("Unsupported task '{}'".format(task))
raise ValueError("Unsupported task '{}'".format(task))
if task == Task.BINARYCLASSIFICATION:
if task == Task.BINARYCLASSIFICATION:
if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
return OmpForestBinaryClassifier(model_parameters)
return OmpForestBinaryClassifier(model_parameters)
 
elif model_parameters.extraction_strategy == 'omp_nn':
 
return NonNegativeOmpForestBinaryClassifier(model_parameters)
elif model_parameters.extraction_strategy == 'random':
elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(**model_parameters.hyperparameters,
return RandomForestClassifier(**model_parameters.hyperparameters,
random_state=model_parameters.seed)
random_state=model_parameters.seed)
@@ -36,8 +40,10 @@ class ModelFactory(object):
@@ -36,8 +40,10 @@ class ModelFactory(object):
else:
else:
raise ValueError('Invalid extraction strategy')
raise ValueError('Invalid extraction strategy')
elif task == Task.REGRESSION:
elif task == Task.REGRESSION:
if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
return OmpForestRegressor(model_parameters)
return OmpForestRegressor(model_parameters)
 
elif model_parameters.extraction_strategy == 'omp_nn':
 
return NonNegativeOmpForestRegressor(model_parameters)
elif model_parameters.extraction_strategy == 'random':
elif model_parameters.extraction_strategy == 'random':
return RandomForestRegressor(**model_parameters.hyperparameters,
return RandomForestRegressor(**model_parameters.hyperparameters,
random_state=model_parameters.seed)
random_state=model_parameters.seed)
@@ -53,8 +59,10 @@ class ModelFactory(object):
@@ -53,8 +59,10 @@ class ModelFactory(object):
else:
else:
raise ValueError('Invalid extraction strategy')
raise ValueError('Invalid extraction strategy')
elif task == Task.MULTICLASSIFICATION:
elif task == Task.MULTICLASSIFICATION:
if model_parameters.extraction_strategy in ['omp', 'omp_nn', 'omp_distillation']:
if model_parameters.extraction_strategy in ['omp', 'omp_distillation']:
return OmpForestMulticlassClassifier(model_parameters)
return OmpForestMulticlassClassifier(model_parameters)
 
elif model_parameters.extraction_strategy == 'omp_nn':
 
raise ValueError('omp_nn is unsuported for multi classification')
elif model_parameters.extraction_strategy == 'random':
elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(**model_parameters.hyperparameters,
return RandomForestClassifier(**model_parameters.hyperparameters,
random_state=model_parameters.seed)
random_state=model_parameters.seed)
Loading