diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py index 5b2af983b24f0d8012b82f294d89a1e0041c6b65..709440b93d837c1bf820a46c5d694744fa6fec55 100644 --- a/skais/ais/ais_trajectory.py +++ b/skais/ais/ais_trajectory.py @@ -39,6 +39,13 @@ def apply_func_on_window(dat, func, radius, on_edge='copy'): data = dat[i - radius:i + radius + 1] result[i - radius] = func(data) return result + elif on_edge == 'ignore': + for i in range(0, dat.shape[0]): + lower_bound = max(0, i-radius) + upper_bound = min(dat.shape[0], i + radius + 1) + data = dat[lower_bound:upper_bound] + result[i] = func(data) + return result else: raise ValueError @@ -96,9 +103,9 @@ class AISTrajectory(AISPoints): return result - def apply_func_on_time_window(self, func, radius, column, new_column=None): + def apply_func_on_time_window(self, func, radius, column, new_column=None, on_edge='copy'): dat = self.df[column].to_numpy() - result = apply_func_on_window(dat, func, radius, on_edge='copy') + result = apply_func_on_window(dat, func, radius, on_edge) if new_column is None: self.df[column] = result diff --git a/skais/tests/ais/test_ais_trajectory.py b/skais/tests/ais/test_ais_trajectory.py index 1ee2d6d860dd0a5a9938f9fc61762eb3b4320963..d4bb0d629c31fccbae8833eef65d4b19ab0db58d 100644 --- a/skais/tests/ais/test_ais_trajectory.py +++ b/skais/tests/ais/test_ais_trajectory.py @@ -320,3 +320,17 @@ class TestAISTrajectory(unittest.TestCase): def test_apply_func_on_window(self): self.assertRaises(ValueError, apply_func_on_window,np.arange(10), 0, 0, 'not valid string') + + def test_apply_func_on_window_ignore(self): + result = apply_func_on_window(np.arange(10), np.mean, 2, 'ignore') + + expected = np.array([1, 1.5, 2, 3, 4, 5, 6, 7, 7.5, 8]) + + np.testing.assert_equal(result, expected) + + def test_apply_func_on_window_ignore_short(self): + result = apply_func_on_window(np.arange(5), np.mean, 10, 'ignore') + + expected = np.array([2, 2, 2, 2, 2]) + + np.testing.assert_equal(result, expected) \ No newline at end of file