Skip to content
Snippets Groups Projects
Commit 53a9edee authored by Denis Arrivault's avatar Denis Arrivault
Browse files

Add some tests

parent 65d21adb
Branches
No related tags found
No related merge requests found
Pipeline #
...@@ -87,3 +87,16 @@ scikit_gilearn.egg-info/ ...@@ -87,3 +87,16 @@ scikit_gilearn.egg-info/
scikit_splearn.egg-info/* scikit_splearn.egg-info/*
.idea/ .idea/
*.tar.gz *.tar.gz
.pytest_cache/
examples/simple_example-2.json.gv
examples/simple_example-2.json.gv.pdf
examples/simple_example-2.yaml.gv
examples/simple_example-2.yaml.gv.pdf
examples/simple_example.json
examples/simple_example.json.gv
examples/simple_example.json.gv.pdf
examples/simple_example.yaml
examples/simple_example.yaml.gv
examples/simple_example.yaml.gv.pdf
examples/simple_example_hankel.json
examples/simple_example_hankel.yaml
...@@ -7,11 +7,11 @@ RUN apt-get update && apt-get install -y \ ...@@ -7,11 +7,11 @@ RUN apt-get update && apt-get install -y \
python3-scipy \ python3-scipy \
graphviz-dev graphviz-dev
RUN pip3 install --upgrade pip RUN pip3 install --upgrade pip
RUN pip3 install pyyaml nose coverage sphinx sphinxcontrib-bibtex RUN pip3 install pyyaml nose coverage pytest pytest-coverage pytest-html sphinx sphinxcontrib-bibtex
# Copy the scikit-splearn sdist in the docker directory and uncomment the following line # Copy the scikit-splearn sdist in the docker directory and uncomment the following line
# if you want to include grackelpy sources in the docker image : # if you want to include grackelpy sources in the docker image :
# ADD scikit-splearn-1.1.0.tar.gz / ADD scikit-splearn-1.1.0.tar.gz /
# cleanup # cleanup
......
...@@ -5,18 +5,17 @@ with-coverage=1 ...@@ -5,18 +5,17 @@ with-coverage=1
cover-package=splearn cover-package=splearn
cover-erase=1 cover-erase=1
cover-html=1 cover-html=1
cover-html-dir=../htmlcov cover-html-dir=../build/htmlcov
# Options for py.test command # Options for py.test command
[pytest] [tool:pytest]
# Specifies a minimal pytest version required for running tests. # Specifies a minimal pytest version required for running tests.
minversion = 2.6 minversion = 3.0
# Specifies the options # Specifies the options
addopts = --resultlog=pytests_results.txt -k "not _old" --cov-report term-missing --cov=sksplearn addopts = --cov-config .coveragerc --html=build/pytest_report.html -k "not _old" --cov-report html:build/htmlcov --cov=splearn
# Set the directory basename patterns to avoid when recursing for test discovery. # Set the directory basename patterns to avoid when recursing for test discovery.
norecursedirs = .git sandboxes .settings .cache htmlcov doc references norecursedirs = .git sandboxes .settings .cache htmlcov doc references build
[coverage:run] [coverage:run]
source=./splearn source=./splearn
......
...@@ -148,7 +148,7 @@ def setup_package(): ...@@ -148,7 +148,7 @@ def setup_package():
read('HISTORY.rst') + '\n\n' + read('HISTORY.rst') + '\n\n' +
read('AUTHORS.rst')), read('AUTHORS.rst')),
packages=["splearn", "splearn.datasets", "splearn.tests", "splearn.tests.datasets"], packages=["splearn", "splearn.datasets", "splearn.tests", "splearn.tests.datasets"],
package_data={'splearn.tests.datasets': ['*.*']}, package_data={'splearn.tests.datasets': ['*']},
url="https://gitlab.lif.univ-mrs.fr/dominique.benielli/scikit-splearn.git", url="https://gitlab.lif.univ-mrs.fr/dominique.benielli/scikit-splearn.git",
license='new BSD', license='new BSD',
author='François Denis and Rémi Eyraud and Denis Arrivault and Dominique Benielli', author='François Denis and Rémi Eyraud and Denis Arrivault and Dominique Benielli',
......
...@@ -51,20 +51,14 @@ class Hankel(object): ...@@ -51,20 +51,14 @@ class Hankel(object):
>>> pT = load_data_sample(adr=train_file) >>> pT = load_data_sample(adr=train_file)
>>> sp = Spectral() >>> sp = Spectral()
>>> sp.fit(X=pT.data) >>> sp.fit(X=pT.data)
>>> lhankel = Hankel( sample=pT.sample, pref=pT.pref, >>> lhankel = Hankel( sample_instance=pT.sample,
>>> suff=pT.suff, fact=pT.fact,
>>> nbL=pT.nbL, nbEx=pT.nbEx, >>> nbL=pT.nbL, nbEx=pT.nbEx,
>>> lrows=6, lcolumns=6, version="classic", >>> lrows=6, lcolumns=6, version="classic",
>>> partial=True, sparse=True, mode_quiet=True).lhankel >>> partial=True, sparse=True, mode_quiet=True).lhankel
- Input: - Input:
:param dict sample_instance: sample dictionary :param Splearn_array sample_instance: instance of Splearn_array
:param dict pref: prefix dictionary
:param dict suff: suffix dictionary
:param dict fact: factor dictionary
:param int nbL: the number of letters
:param int nbS: the number of states
:param lrows: number or list of rows, :param lrows: number or list of rows,
a list of strings if partial=True; a list of strings if partial=True;
otherwise, based on self.pref if version="classic" or otherwise, based on self.pref if version="classic" or
......
...@@ -119,6 +119,10 @@ class Serializer(object): ...@@ -119,6 +119,10 @@ class Serializer(object):
@staticmethod @staticmethod
def __restore_yaml(data_str): def __restore_yaml(data_str):
if data_str is None or isinstance(data_str, (bool, int, float, str)):
return data_str
if isinstance(data_str, list):
return [Serializer.__restore_yaml(k) for k in data_str]
if "dict" in data_str: if "dict" in data_str:
return dict(data_str["dict"]) return dict(data_str["dict"])
if "tuple" in data_str: if "tuple" in data_str:
......
...@@ -35,12 +35,14 @@ ...@@ -35,12 +35,14 @@
# ######### COPYRIGHT ######### # ######### COPYRIGHT #########
import unittest import unittest
import numpy as np import numpy as np
import filecmp
import os import os
from splearn.automaton import Automaton from splearn.automaton import Automaton
from splearn.hankel import Hankel from splearn.hankel import Hankel
from splearn.serializer import Serializer
from splearn.spectral import Spectral
from splearn.tests.datasets.get_dataset_path import get_dataset_path from splearn.tests.datasets.get_dataset_path import get_dataset_path
from splearn.datasets.base import load_data_sample
class UnitaryTest(unittest.TestCase): class UnitaryTest(unittest.TestCase):
...@@ -62,8 +64,9 @@ class UnitaryTest(unittest.TestCase): ...@@ -62,8 +64,9 @@ class UnitaryTest(unittest.TestCase):
def testWriteAutomata(self): def testWriteAutomata(self):
for f in self.formats: for f in self.formats:
Automaton.write(self.A, get_dataset_path(self.input_file + '_2.' + f), format=f) Automaton.write(self.A, get_dataset_path(self.input_file + '_2.' + f), format=f)
self.assertTrue(filecmp.cmp(get_dataset_path(self.input_file + '_2.' + f), B = Automaton.read(get_dataset_path(self.input_file + '_2.' + f), format=f)
get_dataset_path(self.input_file + '.' + f))) for w in self.words:
np.testing.assert_almost_equal(self.A.val(w), B.val(w))
for f in self.formats: for f in self.formats:
os.remove(get_dataset_path(self.input_file + '_2.' + f)) os.remove(get_dataset_path(self.input_file + '_2.' + f))
...@@ -78,6 +81,38 @@ class UnitaryTest(unittest.TestCase): ...@@ -78,6 +81,38 @@ class UnitaryTest(unittest.TestCase):
for f in self.formats: for f in self.formats:
os.remove(get_dataset_path(self.input_file + "_hankel" + "." + f)) os.remove(get_dataset_path(self.input_file + "_hankel" + "." + f))
# def testReadWriteRealHankel(self):
# adr = get_dataset_path("3.pautomac.train")
# data = load_data_sample(adr=adr)
# X = data.data
# sp = Spectral()
# sp = sp.fit(X)
# H = Hankel( sample_instance=X.sample,
# lrows=6, lcolumns=6, version="classic",
# partial=True, sparse=True, mode_quiet=True)
# for f in self.formats:
# Hankel.write(H, get_dataset_path("3.pautomac.train" + "_hankel" + "." + f), format=f)
# Hb = Hankel.read(get_dataset_path("3.pautomac.train" + "_hankel" + "." + f), format = f)
# self.assertEqual(H, Hb)
# for f in self.formats:
# os.remove(get_dataset_path("3.pautomac.train" + "_hankel" + "." + f))
def testOthersSerializationTypes(self):
data = [{'a' : 10, 40 : 'gu'}, {'toto', 5, 2.5, 'b'}, ('gh', 25, 'ko', 1.0)]
data_json_str = Serializer.data_to_json(data)
data_yaml_str = Serializer.data_to_yaml(data)
data_json = Serializer.json_to_data(data_json_str)
data_yaml = Serializer.yaml_to_data(data_yaml_str)
self.assertEqual(data, data_json)
self.assertEqual(data, data_yaml)
data = [1, 2, 3.0]
data_json_str = Serializer.data_to_json(data)
data_yaml_str = Serializer.data_to_yaml(data)
data_json = Serializer.json_to_data(data_json_str)
data_yaml = Serializer.yaml_to_data(data_yaml_str)
self.assertEqual(data, data_json)
self.assertEqual(data, data_yaml)
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testName'] #import sys;sys.argv = ['', 'Test.testName']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment