Skip to content
Snippets Groups Projects
Commit 1e4c3afe authored by Charly Lamothe's avatar Charly Lamothe
Browse files

Merge branch 'master' into 15-integration-sota

parents c9577c89 4e1e84c2
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,20 @@ class ModelParameters(object):
def __init__(self, extracted_forest_size, normalize_D, subsets_used,
normalize_weights, seed, hyperparameters, extraction_strategy):
"""Init of ModelParameters.
Args:
extracted_forest_size (list): list of all the extracted forest
size.
normalize_D (bool): true normalize the distribution, false no
subsets_used (list): which dataset use for randomForest and for OMP
'train', 'dev' or 'train+dev' and combination of two of this.
normalize_weights (bool): if we normalize the weights or no.
seed (int): the seed used for the randomization.
hyperparameters (dict): dict of the hyperparameters of RandomForest
in scikit-learn.
extraction_strategy (str): either 'none', 'random', 'omp'
"""
self._extracted_forest_size = extracted_forest_size
self._normalize_D = normalize_D
self._subsets_used = subsets_used
......
......@@ -7,7 +7,7 @@ Sphinx
coverage
awscli
flake8
pytest
scikit-learn
git+git://github.com/darenr/scikit-optimize@master
python-dotenv
......
import numpy as np
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
def test_binary_classif_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestBinaryClassifier(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
def test_regression_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestRegressor(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
def test_multiclassif_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestMulticlassClassifier(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment