diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py index 7fb2a1198d4d18373c174bbcd710dea0da39ec8c..f8364722878ef74364cb05f363da1ed002425b83 100644 --- a/skais/ais/ais_trajectory.py +++ b/skais/ais/ais_trajectory.py @@ -215,5 +215,11 @@ class AISTrajectory(AISPoints): else: return AISTrajectory(new_df, mmsi=self.mmsi) - def get_time_per_label_shift(self): - pass \ No newline at end of file + def get_time_per_label_shift(self, label_column='label'): + current_label = -1 + result = [] + for index, row in self.df.iterrows(): + if current_label != row[label_column]: + current_label = row[label_column] + result.append((row['ts_sec'], current_label)) + return result \ No newline at end of file diff --git a/skais/tests/ais/test_ais_trajectory.py b/skais/tests/ais/test_ais_trajectory.py index 0566fca5f8f9b65f478f51b05d9423d68e2a1504..bea4d8843011eaabc1ae0b81bf1dc023b3e62706 100644 --- a/skais/tests/ais/test_ais_trajectory.py +++ b/skais/tests/ais/test_ais_trajectory.py @@ -354,7 +354,7 @@ class TestAISTrajectory(unittest.TestCase): trajectory = AISTrajectory( pd.DataFrame( { - "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)], + "label": [1 for _ in range(11)] + [2 for _ in range(10)], "ts_sec": [i for i in range(0, 12001, 600)] } ) @@ -369,7 +369,7 @@ class TestAISTrajectory(unittest.TestCase): trajectory = AISTrajectory( pd.DataFrame( { - "label": [1 for _ in range(0, 101, 10)] + [2 for _ in range(0, 101, 10)]+ [1 for _ in range(0, 101, 10)], + "label": [1 for _ in range(11)] + [2 for _ in range(10)]+ [1 for _ in range(10)], "ts_sec": [i for i in range(0, 18001, 600)] } )