From 5b5a2d326e1cdf84fdc6b937afa8524118e96afd Mon Sep 17 00:00:00 2001
From: Raphael Sturgis <araphael.sturgis@lis-lab.fr>
Date: Thu, 17 Mar 2022 10:26:10 +0100
Subject: [PATCH] added tests and function prototype for getting time stamps'
 for label changes

---
 skais/ais/ais_trajectory.py            |  3 ++
 skais/tests/ais/test_ais_trajectory.py | 47 +++++++++++++++++++++++++-
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py
index 709440b..7fb2a11 100644
--- a/skais/ais/ais_trajectory.py
+++ b/skais/ais/ais_trajectory.py
@@ -214,3 +214,6 @@ class AISTrajectory(AISPoints):
             return self
         else:
             return AISTrajectory(new_df, mmsi=self.mmsi)
+
+    def get_time_per_label_shift(self):
+        pass
\ No newline at end of file
diff --git a/skais/tests/ais/test_ais_trajectory.py b/skais/tests/ais/test_ais_trajectory.py
index d4bb0d6..0566fca 100644
--- a/skais/tests/ais/test_ais_trajectory.py
+++ b/skais/tests/ais/test_ais_trajectory.py
@@ -333,4 +333,49 @@ class TestAISTrajectory(unittest.TestCase):
 
         expected = np.array([2, 2, 2, 2, 2])
 
-        np.testing.assert_equal(result, expected)
\ No newline at end of file
+        np.testing.assert_equal(result, expected)
+
+    def test_get_time_per_label_shift_single_label(self):
+        trajectory = AISTrajectory(
+            pd.DataFrame(
+                {
+                    "label": [1 for _ in range(0, 101, 10)],
+                    "ts_sec": [i for i in range(0, 6001, 600)]
+                }
+            )
+        )
+
+        result = trajectory.get_time_per_label_shift()
+        expected = [(0, 1)]
+
+        self.assertListEqual(result, expected)
+
+    def test_get_time_per_label_shift_label_switch_1(self):
+        trajectory = AISTrajectory(
+            pd.DataFrame(
+                {
+                    "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)],
+                    "ts_sec": [i for i in range(0, 12001, 600)]
+                }
+            )
+        )
+
+        result = trajectory.get_time_per_label_shift()
+        expected = [(0, 1), (6600, 2)]
+
+        self.assertListEqual(result, expected)
+
+    def test_get_time_per_label_shift_label_switch_2(self):
+        trajectory = AISTrajectory(
+            pd.DataFrame(
+                {
+                    "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)]+ [1 for _ in range(0, 101, 10)],
+                    "ts_sec": [i for i in range(0, 18001, 600)]
+                }
+            )
+        )
+
+        result = trajectory.get_time_per_label_shift()
+        expected = [(0, 1), (6600, 2), (12600, 1)]
+
+        self.assertListEqual(result, expected)
\ No newline at end of file
-- 
GitLab