Skip to content
Snippets Groups Projects
Commit 5e01b527 authored by Raphael's avatar Raphael
Browse files

improvements to test engine + tests

parent 243e8f69
No related branches found
No related tags found
2 merge requests!12version 0.2a,!10Resolve "Image creation bugs with 0 size windows"
This commit is part of merge request !10. Comments created here will be created in the context of that merge request.
import tqdm as tqdm import tqdm as tqdm
from skais.process.data_augmentation.data_transformer import DataTransformer
from skais.process.data_augmentation.flip import Flip from skais.process.data_augmentation.flip import Flip
from skais.process.data_augmentation.pipeline import Pipeline from skais.process.data_augmentation.pipeline import Pipeline
from skais.process.data_augmentation.translator import Translator from skais.process.data_augmentation.translator import Translator
class AugmentationEngine: class AugmentationEngine:
def __init__(self, translation_values, flip_values): def __init__(self, translation_values=None, flip_values=None, keep_original=True):
self.pipelines = [] self.pipelines = []
if keep_original:
self.pipelines.append(DataTransformer())
if translation_values is not None:
for tv_long, tv_lat in translation_values: for tv_long, tv_lat in translation_values:
self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)])) self.pipelines.append(Pipeline([Translator(tv_long, tv_lat)]))
if flip_values is not None:
for fv_meridian, fv_parallel in flip_values: for fv_meridian, fv_parallel in flip_values:
self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)])) self.pipelines.append(Pipeline([Flip(fv_meridian, fv_parallel)]))
if flip_values is not None and translation_values is not None:
for tv_long, tv_lat in translation_values: for tv_long, tv_lat in translation_values:
translator = Translator(tv_long, tv_lat) translator = Translator(tv_long, tv_lat)
for fv_meridian, fv_parallel in flip_values: for fv_meridian, fv_parallel in flip_values:
......
class DataTransformer: class DataTransformer:
def transform(self, X): def transform(self, x):
pass return x
import unittest
from skais.ais.ais_trajectory import AISTrajectory
from skais.process.data_augmentation.augmentation_engine import AugmentationEngine
import pandas as pd
class Test_Engine(unittest.TestCase):
def setUp(self):
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
self.trajectories = [t1, t2]
def test_transform_simple_translation(self):
engine = AugmentationEngine(translation_values=[(10, 0), (20, 0)], keep_original=False)
result = engine.transform(self.trajectories)
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [22 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [22 + i for i in range(10)]
}
)
)
t3 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 for _ in range(10)],
'longitude': [32 + i for i in range(10)]
}
)
)
t4 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [-12 + i for i in range(10)],
'longitude': [32 + i for i in range(10)]
}
)
)
expected = [t1, t2, t3, t4]
self.assertEqual(len(expected), len(result))
for t1, t2 in zip(result, expected):
pd.testing.assert_frame_equal(t1.df, t2.df)
def test_transform_simple_flip(self):
engine = AugmentationEngine(flip_values=[None, 0], keep_original=False)
result = engine.transform(self.trajectories)
t1 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [0 - i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
t2 = AISTrajectory(
pd.DataFrame(
{
'ts_sec': [i for i in range(10)],
'latitude': [12 - i for i in range(10)],
'longitude': [12 + i for i in range(10)]
}
)
)
expected = [t1, t2]
self.assertEqual(len(expected), len(result))
for t1, t2 in zip(result, expected):
pd.testing.assert_frame_equal(t1.df, t2.df)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment