From f6dcbe94f7de009b9bbee7b18de37e6c4724a17e Mon Sep 17 00:00:00 2001
From: Raphael Sturgis <raphael.sturgis@lis-lab.fr>
Date: Mon, 11 Apr 2022 11:36:38 +0200
Subject: [PATCH] set bounding box with 2 numbers

---
 skais/ais/ais_trajectory.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/skais/ais/ais_trajectory.py b/skais/ais/ais_trajectory.py
index 7e1307e..a20abf4 100644
--- a/skais/ais/ais_trajectory.py
+++ b/skais/ais/ais_trajectory.py
@@ -1,3 +1,4 @@
+import numbers
 import random
 
 import pandas as pd
@@ -232,7 +233,7 @@ class AISTrajectory(AISPoints):
                 result.append((row['ts_sec'], current_label))
         return result
 
-    def generate_array_from_positions(self, height=256, width=256, link=True, bounding_box='fit', features=None,
+    def generate_array_from_positions(self, height=256, width=256, link=True, bounding_box='fit', ref_index=-1, features=None,
                                       node_size=0):
         nb_channels = 1
 
@@ -241,7 +242,7 @@ class AISTrajectory(AISPoints):
             lower_lon, upper_lon = (min(positions[:, 0]), max(positions[:, 0]))
             lower_lat, upper_lat = (min(positions[:, 1]), max(positions[:, 1]))
         elif bounding_box == 'centered':
-            center_lon, center_lat = positions[-1]
+            center_lon, center_lat = positions[ref_index]
             min_lon, max_lon = (min(positions[:, 0]), max(positions[:, 0]))
             min_lat, max_lat = (min(positions[:, 1]), max(positions[:, 1]))
 
@@ -253,10 +254,19 @@ class AISTrajectory(AISPoints):
             upper_lon = center_lon + distance_to_center
             lower_lon = center_lon - distance_to_center
         elif type(bounding_box) is list:
-            upper_lon = bounding_box[1][0]
-            lower_lon = bounding_box[0][0]
-            upper_lat = bounding_box[1][1]
-            lower_lat = bounding_box[0][1]
+            if type(bounding_box[0]) is not numbers.Number:
+                upper_lon = bounding_box[1][0]
+                lower_lon = bounding_box[0][0]
+                upper_lat = bounding_box[1][1]
+                lower_lat = bounding_box[0][1]
+            else:
+                center_lon, center_lat = positions[ref_index]
+                distance_to_center_lon = bounding_box[0]
+                distance_to_center_lat = bounding_box[1]
+                upper_lat = center_lat + distance_to_center_lat
+                lower_lat = center_lat - distance_to_center_lat
+                upper_lon = center_lon + distance_to_center_lon
+                lower_lon = center_lon - distance_to_center_lon
         else:
             raise ValueError(f"Option not supported: {bounding_box}")
 
-- 
GitLab