From 279b542a5b77698ba94f80706e5c0cd594626949 Mon Sep 17 00:00:00 2001 From: Stephane Chavin <stephane.chavin@lis-lab.fr> Date: Fri, 13 Dec 2024 15:34:20 +0100 Subject: [PATCH] correct --- get_train_val.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/get_train_val.py b/get_train_val.py index 988d96e..3393dc1 100755 --- a/get_train_val.py +++ b/get_train_val.py @@ -4,6 +4,7 @@ import argparse import os import pandas as pd import utils +import glob from tqdm import tqdm @@ -24,7 +25,7 @@ def export_split(argument, entry, path, directory): if argument.test: test_set = entry[2] - test_set.file = ['.'.join(x.split('.')[:-1]) + test_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1]) for num, x in enumerate(test.file)] utils.copy_files_to_directory(test_set.file, path, os.path.join( @@ -32,9 +33,9 @@ def export_split(argument, entry, path, directory): utils.copy_files_to_directory(test_set.file, os.path.join( path, '../images/all'), os.path.join(directory, 'images/test'), 'jpg') - val_set.file = ['.'.join(x.split('.')[:-1]) + val_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1]) for _, x in enumerate(val.file)] - train_set.file = ['.'.join(x.split('.')[:-1]) + train_set.file = ['.'.join(x.split('/')[-1].split('.')[:-1]) for _, x in enumerate(train_set.file)] # Copy the validation set into the folder @@ -71,7 +72,7 @@ def prepare_data(arguments): """ annotations = [] background = [] - for f in tqdm(arguments.path_to_data, desc="Processing", + for f in tqdm(glob.glob(os.path.join(arguments.path_to_data)), desc="Processing", ascii='░▒▓█'): file_annotation = pd.read_csv(f, sep=' ', names=['species', 'x', 'y', 'w', 'h']) if len(file_annotation) == 0: @@ -120,7 +121,7 @@ if __name__ == '__main__': print(f'Train saved in {saved_directory}\n') print('To train your model, use the following command : \n') - current_path = os.getcwd() + current_path = os.path.abspath(os.path.dirname( __file__ )) directory_path = os.path.join(current_path, saved_directory) @@ -131,7 +132,7 @@ if __name__ == '__main__': command = f'python {yolo_path} --data {data_path} --imgsz 640 --epochs 100 --weights {weights_path} --hyp {hyp_path} --cache' print(command,'\n') - if len(background == 0): + if len(background) == 0: print('\u26A0\uFE0F Be aware that it is recommended to have background images that', 'represents 10% of your dataset. If you do not have background, use the script "get_spectrogram.py"', 'with --background arguments. Comptue on recordings that contains multiple type of noise...') -- GitLab