From f9d5091451ce00bc83e8c51e35938fe74b4b5aba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9o=20Bouscarrat?= <leo.bouscarrat@euranova.eu>
Date: Wed, 11 Mar 2020 17:16:01 +0100
Subject: [PATCH] Add test for base OmpForest class

---
 code/bolsonaro/models/model_parameters.py | 14 ++++++
 tests/test_bolsonaro.py                   | 58 +++++++++++++++++++++++
 2 files changed, 72 insertions(+)
 create mode 100644 tests/test_bolsonaro.py

diff --git a/code/bolsonaro/models/model_parameters.py b/code/bolsonaro/models/model_parameters.py
index a3286ed..2009190 100644
--- a/code/bolsonaro/models/model_parameters.py
+++ b/code/bolsonaro/models/model_parameters.py
@@ -7,6 +7,20 @@ class ModelParameters(object):
 
     def __init__(self, extracted_forest_size, normalize_D, subsets_used,
         normalize_weights, seed, hyperparameters, extraction_strategy):
+        """Init of ModelParameters.
+        
+        Args:
+            extracted_forest_size (list): list of all the extracted forest
+                size.
+            normalize_D (bool): true normalize the distribution, false no
+            subsets_used (list): which dataset use for randomForest and for OMP
+                'train', 'dev' or 'train+dev' and combination of two of this.
+            normalize_weights (bool): if we normalize the weights or no.
+            seed (int): the seed used for the randomization.
+            hyperparameters (dict): dict of the hyperparameters of RandomForest
+                in scikit-learn.
+            extraction_strategy (str): either 'none', 'random', 'omp'
+        """
         self._extracted_forest_size = extracted_forest_size
         self._normalize_D = normalize_D
         self._subsets_used = subsets_used
diff --git a/tests/test_bolsonaro.py b/tests/test_bolsonaro.py
new file mode 100644
index 0000000..e282f20
--- /dev/null
+++ b/tests/test_bolsonaro.py
@@ -0,0 +1,58 @@
+import numpy as np
+
+from bolsonaro.models.model_parameters import ModelParameters
+from bolsonaro.models.omp_forest_classifier import OmpForestBinaryClassifier, OmpForestMulticlassClassifier
+from bolsonaro.models.omp_forest_regressor import OmpForestRegressor
+
+
+def test_binary_classif_omp():
+
+    model_parameters = ModelParameters(
+        1, False, ['train+dev', 'train+dev'], False, 1,
+        {'n_estimators': 100}, 'omp'
+    )
+
+    omp_forest = OmpForestBinaryClassifier(model_parameters)
+    X_train = [[1, 0], [0, 1]]
+    y_train = [-1, 1]
+
+    omp_forest.fit(X_train, y_train, X_train, y_train)
+
+    results = omp_forest.predict(X_train)
+
+    assert isinstance(results, np.ndarray)
+
+
+def test_regression_omp():
+
+    model_parameters = ModelParameters(
+        1, False, ['train+dev', 'train+dev'], False, 1,
+        {'n_estimators': 100}, 'omp'
+    )
+
+    omp_forest = OmpForestRegressor(model_parameters)
+    X_train = [[1, 0], [0, 1]]
+    y_train = [-1, 1]
+
+    omp_forest.fit(X_train, y_train, X_train, y_train)
+
+    results = omp_forest.predict(X_train)
+
+    assert isinstance(results, np.ndarray)
+
+def test_multiclassif_omp():
+
+    model_parameters = ModelParameters(
+        1, False, ['train+dev', 'train+dev'], False, 1,
+        {'n_estimators': 100}, 'omp'
+    )
+
+    omp_forest = OmpForestMulticlassClassifier(model_parameters)
+    X_train = [[1, 0], [0, 1]]
+    y_train = [-1, 1]
+
+    omp_forest.fit(X_train, y_train, X_train, y_train)
+
+    results = omp_forest.predict(X_train)
+
+    assert isinstance(results, np.ndarray)
-- 
GitLab