Skip to content
Snippets Groups Projects
Commit 5e01b527 authored by Raphael's avatar Raphael
Browse files

improvements to test engine + tests

parent 243e8f69
No related branches found
No related tags found
2 merge requests!12version 0.2a,!10Resolve "Image creation bugs with 0 size windows"
import tqdm as tqdm import tqdm as tqdm
from skais.process.data_augmentation.data_transformer import DataTransformer
from skais.process.data_augmentation.flip import Flip from skais.process.data_augmentation.flip import Flip
from skais.process.data_augmentation.pipeline import Pipeline from skais.process.data_augmentation.pipeline import Pipeline
from skais.process.data_augmentation.translator import Translator from skais.process.data_augmentation.translator import Translator
class AugmentationEngine: class AugmentationEngine:
def __init__(self, translation_values, flip_values): def __init__(self, translation_values=None, flip_values=None, keep_original=True):
self.pipelines = [] self.pipelines = []
if keep_original:
self.pipelines.append(DataTransformer())
if translation_values is not None:
for tv_long, tv_lat in translation_values: for tv_long, tv_lat in translation_values:
self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)]))
if flip_values is not None:
for fv_meridian, fv_parallel in flip_values: for fv_meridian, fv_parallel in flip_values:
self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)]))
if flip_values is not None and translation_values is not None:
for tv_long, tv_lat in translation_values: for tv_long, tv_lat in translation_values:
translator = Translator(tv_long, tv_lat) translator = Translator(tv_long, tv_lat)
for fv_meridian, fv_parallel in flip_values: for fv_meridian, fv_parallel in flip_values:
......
class DataTransformer: class DataTransformer:
def transform(self, X): def transform(self, x):
pass return x
import unittest
from skais.ais.ais_trajectory import AISTrajectory
from skais.process.data_augmentation.augmentation_engine import AugmentationEngine
import pandas as pd
class Test_Engine(unittest.TestCase):
def setUp(self):
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
self.trajectories = [t1, t2]
def test_transform_simple_translation(self):
engine = AugmentationEngine(translation_values=[(10, 0), (20, 0)], keep_original=False)
result = engine.transform(self.trajectories)
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [22 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [22 + i for i in range(10)]
}
)
)
t3 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [32 + i for i in range(10)]
}
)
)
t4 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [32 + i for i in range(10)]
}
)
)
expected = [t1, t2, t3, t4]
self.assertEqual(len(expected), len(result))
for t1, t2 in zip(result, expected):
pd.testing.assert_frame_equal(t1.df, t2.df)
def test_transform_simple_flip(self):
engine = AugmentationEngine(flip_values=[None, 0], keep_original=False)
result = engine.transform(self.trajectories)
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 - i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [12 - i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
expected = [t1, t2]
self.assertEqual(len(expected), len(result))
for t1, t2 in zip(result, expected):
pd.testing.assert_frame_equal(t1.df, t2.df)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment