From 5e01b5271f266e9a77271a763e8f52f97ea29ea8 Mon Sep 17 00:00:00 2001 From: Raphael <raphael.sturgis@gmail.com> Date: Tue, 7 Dec 2021 08:48:21 +0100 Subject: [PATCH] improvements to test engine + tests --- .../data_augmentation/augmentation_engine.py | 26 ++-- .../data_augmentation/data_transformer.py | 4 +- .../process/data_augmentation/test_engine.py | 111 ++++++++++++++++++ 3 files changed, 129 insertions(+), 12 deletions(-) create mode 100644 skais/tests/process/data_augmentation/test_engine.py diff --git a/skais/process/data_augmentation/augmentation_engine.py b/skais/process/data_augmentation/augmentation_engine.py index d61ff1b..27004aa 100644 --- a/skais/process/data_augmentation/augmentation_engine.py +++ b/skais/process/data_augmentation/augmentation_engine.py @@ -1,25 +1,31 @@ import tqdm as tqdm +from skais.process.data_augmentation.data_transformer import DataTransformer from skais.process.data_augmentation.flip import Flip from skais.process.data_augmentation.pipeline import Pipeline from skais.process.data_augmentation.translator import Translator class AugmentationEngine: - def __init__(self, translation_values, flip_values): + def __init__(self, translation_values=None, flip_values=None, keep_original=True): self.pipelines = [] + if keep_original: + self.pipelines.append(DataTransformer()) - for tv_long, tv_lat in translation_values: - self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) + if translation_values is not None: + for tv_long, tv_lat in translation_values: + self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) - for fv_meridian, fv_parallel in flip_values: - self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) - - for tv_long, tv_lat in translation_values: - translator = Translator(tv_long, tv_lat) + if flip_values is not None: for fv_meridian, fv_parallel in flip_values: - flip = Flip(fv_meridian, fv_parallel) - self.pipelines.append(Pipeline([translator, flip])) + self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) + + if flip_values is not None and translation_values is not None: + for tv_long, tv_lat in translation_values: + translator = Translator(tv_long, tv_lat) + for fv_meridian, fv_parallel in flip_values: + flip = Flip(fv_meridian, fv_parallel) + self.pipelines.append(Pipeline([translator, flip])) def transform(self, x, verbose=0): results = x.copy() diff --git a/skais/process/data_augmentation/data_transformer.py b/skais/process/data_augmentation/data_transformer.py index 0e4e1c4..a2c4332 100644 --- a/skais/process/data_augmentation/data_transformer.py +++ b/skais/process/data_augmentation/data_transformer.py @@ -1,3 +1,3 @@ class DataTransformer: - def transform(self, X): - pass + def transform(self, x): + return x diff --git a/skais/tests/process/data_augmentation/test_engine.py b/skais/tests/process/data_augmentation/test_engine.py new file mode 100644 index 0000000..c77c55d --- /dev/null +++ b/skais/tests/process/data_augmentation/test_engine.py @@ -0,0 +1,111 @@ +import unittest + +from skais.ais.ais_trajectory import AISTrajectory +from skais.process.data_augmentation.augmentation_engine import AugmentationEngine + +import pandas as pd + +class Test_Engine(unittest.TestCase): + def setUp(self): + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + + self.trajectories = [t1, t2] + + + def test_transform_simple_translation(self): + engine = AugmentationEngine(translation_values=[(10, 0), (20, 0)], keep_original=False) + + result = engine.transform(self.trajectories) + + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [22 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [22 + i for i in range(10)] + } + ) + ) + + t3 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [32 + i for i in range(10)] + } + ) + ) + + t4 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [32 + i for i in range(10)] + } + ) + ) + expected = [t1, t2, t3, t4] + + self.assertEqual(len(expected), len(result)) + for t1, t2 in zip(result, expected): + pd.testing.assert_frame_equal(t1.df, t2.df) + + def test_transform_simple_flip(self): + engine = AugmentationEngine(flip_values=[None, 0], keep_original=False) + + result = engine.transform(self.trajectories) + + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 - i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [12 - i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + expected = [t1, t2] + + self.assertEqual(len(expected), len(result)) + for t1, t2 in zip(result, expected): + pd.testing.assert_frame_equal(t1.df, t2.df) + +if __name__ == '__main__': + unittest.main() -- GitLab