Skip to content
Snippets Groups Projects

Resolve "Experiment pipeline"

Merged Charly Lamothe requested to merge 12-experiment-pipeline into master
11 files
+ 190
69
Compare changes
  • Side-by-side
  • Inline
Files
11
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
from bolsonaro.data.task import Task
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.similarity_forest_regressor import SimilarityForestRegressor
from bolsonaro.data.task import Task
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
import os
import pickle
@@ -11,22 +13,35 @@ class ModelFactory(object):
@staticmethod
def build(task, model_parameters):
if task not in [Task.BINARYCLASSIFICATION, Task.REGRESSION, Task.MULTICLASSIFICATION]:
raise ValueError("Unsupported task '{}'".format(task))
if task == Task.BINARYCLASSIFICATION:
model_func = OmpForestBinaryClassifier
if model_parameters.extraction_strategy == 'omp':
return OmpForestBinaryClassifier(model_parameters)
elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(n_estimators=model_parameters.extracted_forest_size,
random_state=model_parameters.seed)
else:
return RandomForestClassifier(n_estimators=model_parameters.hyperparameters['n_estimators'],
random_state=model_parameters.seed)
elif task == Task.REGRESSION:
model_func = OmpForestRegressor
if model_parameters.extraction_strategy == 'omp':
return OmpForestRegressor(model_parameters)
elif model_parameters.extraction_strategy == 'random':
return RandomForestRegressor(n_estimators=model_parameters.extracted_forest_size,
random_state=model_parameters.seed)
elif model_parameters.extraction_strategy == 'similarity':
return SimilarityForestRegressor(model_parameters)
else:
return RandomForestRegressor(n_estimators=model_parameters.hyperparameters['n_estimators'],
random_state=model_parameters.seed)
elif task == Task.MULTICLASSIFICATION:
model_func = OmpForestMulticlassClassifier
if model_parameters.extraction_strategy == 'omp':
return OmpForestMulticlassClassifier(model_parameters)
elif model_parameters.extraction_strategy == 'random':
return RandomForestClassifier(n_estimators=model_parameters.extracted_forest_size,
random_state=model_parameters.seed)
else:
raise ValueError("Unsupported task '{}'".format(task))
return model_func(model_parameters)
@staticmethod
def load(task, directory_path, experiment_id, model_raw_results):
raise NotImplementedError
model_parameters = ModelParameters.load(directory_path, experiment_id)
model = ModelFactory.build(task, model_parameters)
# todo faire ce qu'il faut ici pour rétablir correctement le modèle
model.set_forest(model_raw_results.model_object.forest)
model.set_weights(model_raw_results.model_object.weights)
return model
return RandomForestClassifier(n_estimators=model_parameters.hyperparameters['n_estimators'],
random_state=model_parameters.seed)
Loading