Skip to content
Snippets Groups Projects
Commit f9d50914 authored by Léo Bouscarrat's avatar Léo Bouscarrat
Browse files

Add test for base OmpForest class

parent a1a7f767
No related branches found
No related tags found
1 merge request!21Resolve "Add some tests"
......@@ -7,6 +7,20 @@ class ModelParameters(object):
def __init__(self, extracted_forest_size, normalize_D, subsets_used,
normalize_weights, seed, hyperparameters, extraction_strategy):
"""Init of ModelParameters.
Args:
extracted_forest_size (list): list of all the extracted forest
size.
normalize_D (bool): true normalize the distribution, false no
subsets_used (list): which dataset use for randomForest and for OMP
'train', 'dev' or 'train+dev' and combination of two of this.
normalize_weights (bool): if we normalize the weights or no.
seed (int): the seed used for the randomization.
hyperparameters (dict): dict of the hyperparameters of RandomForest
in scikit-learn.
extraction_strategy (str): either 'none', 'random', 'omp'
"""
self._extracted_forest_size = extracted_forest_size
self._normalize_D = normalize_D
self._subsets_used = subsets_used
......
import numpy as np
from bolsonaro.models.model_parameters import ModelParameters
from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
def test_binary_classif_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestBinaryClassifier(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
def test_regression_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestRegressor(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
def test_multiclassif_omp():
model_parameters = ModelParameters(
1, False, ['train+dev', 'train+dev'], False, 1,
{'n_estimators': 100}, 'omp'
)
omp_forest = OmpForestMulticlassClassifier(model_parameters)
X_train = [[1, 0], [0, 1]]
y_train = [-1, 1]
omp_forest.fit(X_train, y_train, X_train, y_train)
results = omp_forest.predict(X_train)
assert isinstance(results, np.ndarray)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment