diff --git a/skais/process/data_augmentation/augmentation_engine.py b/skais/process/data_augmentation/augmentation_engine.py index d61ff1baba4ad46e9301d4b4cec461d62f629358..27004aa58f514f39de26a75242f8d92a4d53f490 100644 --- a/skais/process/data_augmentation/augmentation_engine.py +++ b/skais/process/data_augmentation/augmentation_engine.py @@ -1,25 +1,31 @@ import tqdm as tqdm +from skais.process.data_augmentation.data_transformer import DataTransformer from skais.process.data_augmentation.flip import Flip from skais.process.data_augmentation.pipeline import Pipeline from skais.process.data_augmentation.translator import Translator class AugmentationEngine: - def __init__(self, translation_values, flip_values): + def __init__(self, translation_values=None, flip_values=None, keep_original=True): self.pipelines = [] + if keep_original: + self.pipelines.append(DataTransformer()) - for tv_long, tv_lat in translation_values: - self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) + if translation_values is not None: + for tv_long, tv_lat in translation_values: + self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) - for fv_meridian, fv_parallel in flip_values: - self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) - - for tv_long, tv_lat in translation_values: - translator = Translator(tv_long, tv_lat) + if flip_values is not None: for fv_meridian, fv_parallel in flip_values: - flip = Flip(fv_meridian, fv_parallel) - self.pipelines.append(Pipeline([translator, flip])) + self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) + + if flip_values is not None and translation_values is not None: + for tv_long, tv_lat in translation_values: + translator = Translator(tv_long, tv_lat) + for fv_meridian, fv_parallel in flip_values: + flip = Flip(fv_meridian, fv_parallel) + self.pipelines.append(Pipeline([translator, flip])) def transform(self, x, verbose=0): results = x.copy() diff --git a/skais/process/data_augmentation/data_transformer.py b/skais/process/data_augmentation/data_transformer.py index 0e4e1c4292fd29ab4f9ee95e723ee478976c9c7e..a2c4332e65d65efdd310a0e17069498a2b6eecbd 100644 --- a/skais/process/data_augmentation/data_transformer.py +++ b/skais/process/data_augmentation/data_transformer.py @@ -1,3 +1,3 @@ class DataTransformer: - def transform(self, X): - pass + def transform(self, x): + return x diff --git a/skais/tests/process/data_augmentation/test_engine.py b/skais/tests/process/data_augmentation/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c77c55db68997644971bf9b2ce3dbf2a0b1e4e4d --- /dev/null +++ b/skais/tests/process/data_augmentation/test_engine.py @@ -0,0 +1,111 @@ +import unittest + +from skais.ais.ais_trajectory import AISTrajectory +from skais.process.data_augmentation.augmentation_engine import AugmentationEngine + +import pandas as pd + +class Test_Engine(unittest.TestCase): + def setUp(self): + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + + self.trajectories = [t1, t2] + + + def test_transform_simple_translation(self): + engine = AugmentationEngine(translation_values=[(10, 0), (20, 0)], keep_original=False) + + result = engine.transform(self.trajectories) + + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [22 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [22 + i for i in range(10)] + } + ) + ) + + t3 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 for _ in range(10)], + 'longitude': [32 + i for i in range(10)] + } + ) + ) + + t4 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [-12 + i for i in range(10)], + 'longitude': [32 + i for i in range(10)] + } + ) + ) + expected = [t1, t2, t3, t4] + + self.assertEqual(len(expected), len(result)) + for t1, t2 in zip(result, expected): + pd.testing.assert_frame_equal(t1.df, t2.df) + + def test_transform_simple_flip(self): + engine = AugmentationEngine(flip_values=[None, 0], keep_original=False) + + result = engine.transform(self.trajectories) + + t1 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [0 - i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + t2 = AISTrajectory( + pd.DataFrame( + { + 'ts_sec': [i for i in range(10)], + 'latitude': [12 - i for i in range(10)], + 'longitude': [12 + i for i in range(10)] + } + ) + ) + expected = [t1, t2] + + self.assertEqual(len(expected), len(result)) + for t1, t2 in zip(result, expected): + pd.testing.assert_frame_equal(t1.df, t2.df) + +if __name__ == '__main__': + unittest.main()