From f9d5091451ce00bc83e8c51e35938fe74b4b5aba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu> Date: Wed, 11 Mar 2020 17:16:01 +0100 Subject: [PATCH] Add test for base OmpForest class --- code/bolsonaro/models/model_parameters.py | 14 ++++++ tests/test_bolsonaro.py | 58 +++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/test_bolsonaro.py diff --git a/code/bolsonaro/models/model_parameters.py b/code/bolsonaro/models/model_parameters.py index a3286ed..2009190 100644 --- a/code/bolsonaro/models/model_parameters.py +++ b/code/bolsonaro/models/model_parameters.py @@ -7,6 +7,20 @@ class ModelParameters(object): def __init__(self, extracted_forest_size, normalize_D, subsets_used, normalize_weights, seed, hyperparameters, extraction_strategy): + """Init of ModelParameters. + + Args: + extracted_forest_size (list): list of all the extracted forest + size. + normalize_D (bool): true normalize the distribution, false no + subsets_used (list): which dataset use for randomForest and for OMP + 'train', 'dev' or 'train+dev' and combination of two of this. + normalize_weights (bool): if we normalize the weights or no. + seed (int): the seed used for the randomization. + hyperparameters (dict): dict of the hyperparameters of RandomForest + in scikit-learn. + extraction_strategy (str): either 'none', 'random', 'omp' + """ self._extracted_forest_size = extracted_forest_size self._normalize_D = normalize_D self._subsets_used = subsets_used diff --git a/tests/test_bolsonaro.py b/tests/test_bolsonaro.py new file mode 100644 index 0000000..e282f20 --- /dev/null +++ b/tests/test_bolsonaro.py @@ -0,0 +1,58 @@ +import numpy as np + +from bolsonaro.models.model_parameters import ModelParameters +from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier +from bolsonaro.models.omp_forest_regressor import OmpForestRegressor + + +def test_binary_classif_omp(): + + model_parameters = ModelParameters( + 1, False, ['train+dev', 'train+dev'], False, 1, + {'n_estimators': 100}, 'omp' + ) + + omp_forest = OmpForestBinaryClassifier(model_parameters) + X_train = [[1, 0], [0, 1]] + y_train = [-1, 1] + + omp_forest.fit(X_train, y_train, X_train, y_train) + + results = omp_forest.predict(X_train) + + assert isinstance(results, np.ndarray) + + +def test_regression_omp(): + + model_parameters = ModelParameters( + 1, False, ['train+dev', 'train+dev'], False, 1, + {'n_estimators': 100}, 'omp' + ) + + omp_forest = OmpForestRegressor(model_parameters) + X_train = [[1, 0], [0, 1]] + y_train = [-1, 1] + + omp_forest.fit(X_train, y_train, X_train, y_train) + + results = omp_forest.predict(X_train) + + assert isinstance(results, np.ndarray) + +def test_multiclassif_omp(): + + model_parameters = ModelParameters( + 1, False, ['train+dev', 'train+dev'], False, 1, + {'n_estimators': 100}, 'omp' + ) + + omp_forest = OmpForestMulticlassClassifier(model_parameters) + X_train = [[1, 0], [0, 1]] + y_train = [-1, 1] + + omp_forest.fit(X_train, y_train, X_train, y_train) + + results = omp_forest.predict(X_train) + + assert isinstance(results, np.ndarray) -- GitLab