From 279b542a5b77698ba94f80706e5c0cd594626949 Mon Sep 17 00:00:00 2001
From: Stephane Chavin <stephane.chavin@lis-lab.fr>
Date: Fri, 13 Dec 2024 15:34:20 +0100
Subject: [PATCH] correct

---
 get_train_val.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/get_train_val.py b/get_train_val.py
index 988d96e..3393dc1 100755
--- a/get_train_val.py
+++ b/get_train_val.py
@@ -4,6 +4,7 @@ import argparse
 import os
 import pandas as pd
 import utils
+import glob
 
 from tqdm import tqdm
 
@@ -24,7 +25,7 @@ def export_split(argument, entry, path, directory):
 
     if argument.test:
         test_set = entry[2]
-        test_set.file = ['.'.join(x.split('.')[:-1])
+        test_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1])
                          for num, x in enumerate(test.file)]
 
         utils.copy_files_to_directory(test_set.file, path, os.path.join(
@@ -32,9 +33,9 @@ def export_split(argument, entry, path, directory):
         utils.copy_files_to_directory(test_set.file, os.path.join(
             path, '../images/all'), os.path.join(directory, 'images/test'), 'jpg')
 
-    val_set.file = ['.'.join(x.split('.')[:-1])
+    val_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1])
                     for _, x in enumerate(val.file)]
-    train_set.file = ['.'.join(x.split('.')[:-1])
+    train_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1])
                       for _, x in enumerate(train_set.file)]
 
     # Copy the validation set into the folder
@@ -71,7 +72,7 @@ def prepare_data(arguments):
     """
     annotations = [] 
     background = []  
-    for f in tqdm(arguments.path_to_data, desc="Processing", 
+    for f in tqdm(glob.glob(os.path.join(arguments.path_to_data)), desc="Processing", 
                                                             ascii='░▒▓█'):
         file_annotation = pd.read_csv(f, sep=' ', names=['species', 'x', 'y', 'w', 'h'])
         if len(file_annotation) == 0:
@@ -120,7 +121,7 @@ if __name__ == '__main__':
     print(f'Train saved in {saved_directory}\n')
     print('To train your model, use the following command : \n')
 
-    current_path = os.getcwd()
+    current_path = os.path.abspath(os.path.dirname( __file__ ))
 
     directory_path = os.path.join(current_path, saved_directory)
 
@@ -131,7 +132,7 @@ if __name__ == '__main__':
 
     command = f'python {yolo_path} --data {data_path} --imgsz 640 --epochs 100 --weights {weights_path} --hyp {hyp_path} --cache'
     print(command,'\n')
-    if len(background == 0):
+    if len(background) == 0:
         print('\u26A0\uFE0F   Be aware that it is recommended to have background images that',
         'represents 10% of your dataset. If you do not have background, use the script "get_spectrogram.py"',
         'with --background arguments. Comptue on recordings that contains multiple type of noise...')
-- 
GitLab