diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py index e6ba885dcafca82f0a023eee0196003472ce28b4..a08674ec4e24e3110fb4a35af84731bbb632f83c 100644 --- a/skais/ais/ais_trajectory.py +++ b/skais/ais/ais_trajectory.py @@ -236,22 +236,39 @@ class AISTrajectory(AISPoints): node_size=0): nb_channels = 1 - if bounding_box != 'fit': - raise ValueError("feature not implemented") - positions = self.df[['longitude', 'latitude']].to_numpy() - min_lon, max_lon = (min(positions[:, 0]), max(positions[:, 0])) - min_lat, max_lat = (min(positions[:, 1]), max(positions[:, 1])) - if min_lat == max_lat: - min_lat -= 1 - max_lat += 1 - if min_lon == max_lon: - min_lon -= 1 - max_lon += 1 + if bounding_box == 'fit': + positions = self.df[['longitude', 'latitude']].to_numpy() + lower_lon, upper_lon = (min(positions[:, 0]), max(positions[:, 0])) + lower_lat, upper_lat = (min(positions[:, 1]), max(positions[:, 1])) + elif bounding_box == 'centered': + positions = self.df[['longitude', 'latitude']].to_numpy() + center_lon, center_lat = positions[-1] + min_lon, max_lon = (min(positions[:, 0]), max(positions[:, 0])) + min_lat, max_lat = (min(positions[:, 1]), max(positions[:, 1])) + + distance_to_center = max(center_lon - min_lon, max_lon - center_lon, center_lat - min_lat, + max_lat - center_lat) + + upper_lat = center_lat + distance_to_center + lower_lat = center_lat - distance_to_center + upper_lon = center_lon + distance_to_center + lower_lon = center_lon - distance_to_center + + else: + raise ValueError(f"Option not supported: {bounding_box}") + + if lower_lat == upper_lat: + lower_lat -= 1 + upper_lat += 1 + if lower_lon == upper_lon: + lower_lon -= 1 + upper_lon += 1 if features is None: data = np.zeros((height, width, nb_channels), dtype=np.uint8) for longitude, latitude in positions: - x_coord, y_coord = get_coord(latitude, longitude, height, width, min_lat, max_lat, min_lon, max_lon) + x_coord, y_coord = get_coord(latitude, longitude, height, width, lower_lat, upper_lat, lower_lon, + upper_lon) x_lower_bound = max(0, x_coord - node_size) x_upper_bound = min(height - 1, x_coord + node_size) @@ -265,8 +282,9 @@ class AISTrajectory(AISPoints): if link: lon, lat = positions[0, 0], positions[0, 1] for longitude, latitude in positions[1:]: - x_prv, y_prev = get_coord(lat, lon, height, width, min_lat, max_lat, min_lon, max_lon) - x_nxt, y_nxt = get_coord(latitude, longitude, height, width, min_lat, max_lat, min_lon, max_lon) + x_prv, y_prev = get_coord(lat, lon, height, width, lower_lat, upper_lat, lower_lon, upper_lon) + x_nxt, y_nxt = get_coord(latitude, longitude, height, width, lower_lat, upper_lat, lower_lon, + upper_lon) lon, lat = longitude, latitude for x, y in bresenham(x_prv, y_prev, x_nxt, y_nxt): @@ -288,7 +306,8 @@ class AISTrajectory(AISPoints): for pos, f in zip(positions, features_vectors): latitude = pos[1] longitude = pos[0] - x_coord, y_coord = get_coord(latitude, longitude, height, width, min_lat, max_lat, min_lon, max_lon) + x_coord, y_coord = get_coord(latitude, longitude, height, width, lower_lat, upper_lat, lower_lon, + upper_lon) value = __get_image_value__(f, bounds) x_lower_bound = max(0, x_coord - node_size) x_upper_bound = min(height - 1, x_coord + node_size) @@ -307,8 +326,9 @@ class AISTrajectory(AISPoints): for pos, f in zip(positions[1:], features_vectors[1:]): latitude = pos[1] longitude = pos[0] - x_prv, y_prev = get_coord(lat, lon, height, width, min_lat, max_lat, min_lon, max_lon) - x_nxt, y_nxt = get_coord(latitude, longitude, height, width, min_lat, max_lat, min_lon, max_lon) + x_prv, y_prev = get_coord(lat, lon, height, width, lower_lat, upper_lat, lower_lon, upper_lon) + x_nxt, y_nxt = get_coord(latitude, longitude, height, width, lower_lat, upper_lat, lower_lon, + upper_lon) lon, lat = longitude, latitude for x, y in bresenham(x_prv, y_prev, x_nxt, y_nxt): for i, v in enumerate(value): diff --git a/skais/tests/ais/test_ais_trajectory.py b/skais/tests/ais/test_ais_trajectory.py index 4298ed2760e70bde033896996ba664487eebfeb8..68085c864532384f6046e8dfa98dcce3dfa12b07 100644 --- a/skais/tests/ais/test_ais_trajectory.py +++ b/skais/tests/ais/test_ais_trajectory.py @@ -554,15 +554,15 @@ class TestAISTrajectoryImageGeneration(unittest.TestCase): def test_generate_array_centered(self): result = self.trajectory.generate_array_from_positions(height=9, width=9, link=False, bounding_box='centered', features=None, node_size=0).reshape((9, 9)) - expected = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0], + expected = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0, 0, 0]]) + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0]]) np.testing.assert_array_equal(result, expected)