From 5b5a2d326e1cdf84fdc6b937afa8524118e96afd Mon Sep 17 00:00:00 2001 From: Raphael Sturgis <araphael.sturgis@lis-lab.fr> Date: Thu, 17 Mar 2022 10:26:10 +0100 Subject: [PATCH] added tests and function prototype for getting time stamps' for label changes --- skais/ais/ais_trajectory.py | 3 ++ skais/tests/ais/test_ais_trajectory.py | 47 +++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py index 709440b..7fb2a11 100644 --- a/skais/ais/ais_trajectory.py +++ b/skais/ais/ais_trajectory.py @@ -214,3 +214,6 @@ class AISTrajectory(AISPoints): return self else: return AISTrajectory(new_df, mmsi=self.mmsi) + + def get_time_per_label_shift(self): + pass \ No newline at end of file diff --git a/skais/tests/ais/test_ais_trajectory.py b/skais/tests/ais/test_ais_trajectory.py index d4bb0d6..0566fca 100644 --- a/skais/tests/ais/test_ais_trajectory.py +++ b/skais/tests/ais/test_ais_trajectory.py @@ -333,4 +333,49 @@ class TestAISTrajectory(unittest.TestCase): expected = np.array([2, 2, 2, 2, 2]) - np.testing.assert_equal(result, expected) \ No newline at end of file + np.testing.assert_equal(result, expected) + + def test_get_time_per_label_shift_single_label(self): + trajectory = AISTrajectory( + pd.DataFrame( + { + "label": [1 for _ in range(0, 101, 10)], + "ts_sec": [i for i in range(0, 6001, 600)] + } + ) + ) + + result = trajectory.get_time_per_label_shift() + expected = [(0, 1)] + + self.assertListEqual(result, expected) + + def test_get_time_per_label_shift_label_switch_1(self): + trajectory = AISTrajectory( + pd.DataFrame( + { + "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)], + "ts_sec": [i for i in range(0, 12001, 600)] + } + ) + ) + + result = trajectory.get_time_per_label_shift() + expected = [(0, 1), (6600, 2)] + + self.assertListEqual(result, expected) + + def test_get_time_per_label_shift_label_switch_2(self): + trajectory = AISTrajectory( + pd.DataFrame( + { + "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)]+ [1 for _ in range(0, 101, 10)], + "ts_sec": [i for i in range(0, 18001, 600)] + } + ) + ) + + result = trajectory.get_time_per_label_shift() + expected = [(0, 1), (6600, 2), (12600, 1)] + + self.assertListEqual(result, expected) \ No newline at end of file -- GitLab