From 688f410922c6befa9f045cedc7ca4c1937b171f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Lehnhoff?= <loic.lehnhoff@gmail.com>
Date: Mon, 25 Mar 2024 12:46:23 +0100
Subject: [PATCH] Small changes

~ Modification of README
~ script optimizations
~ minor bug corrections
---
 PyAVA.py                          | 11 ++++++++++-
 README.md                         | 13 +++++++------
 args.py                           | 16 +++++++++++++++-
 functions.py                      | 31 ++++++++++++++++++++++++-------
 interface.py                      |  4 ++--
 line_clicker/lite_line_clicker.py |  1 -
 6 files changed, 58 insertions(+), 18 deletions(-)

diff --git a/PyAVA.py b/PyAVA.py
index d1dfe30..51d9fb4 100644
--- a/PyAVA.py
+++ b/PyAVA.py
@@ -9,7 +9,7 @@ from args import fetch_inputs
 ##### MAIN #####
 if __name__ == '__main__':
     # fetching inputs.
-    dir_explore, max_traj, new_sr, output, modify, initial_basename, parameters = fetch_inputs()
+    audio_file_path, dir_explore, max_traj, new_sr, output, modify, initial_basename, parameters = fetch_inputs()
 
     if modify:
         with open(os.path.join(output, modify), "r") as f:
@@ -24,6 +24,15 @@ if __name__ == '__main__':
             coords_to_change,
             parameters)
 
+    elif len(audio_file_path)>0:
+        MainWindow = App(
+            dir_explore, 
+            max_traj, 
+            new_sr, 
+            output, 
+            audio_file_path,
+            {},
+            parameters)
     else:
         # open explorer to select first file
         groot = Tk()
diff --git a/README.md b/README.md
index f7fba7b..290b802 100644
--- a/README.md
+++ b/README.md
@@ -12,12 +12,13 @@
 - [x] Same tools as matplotlib.pyplot plots.
 - [x] Spectrogram contour annotations. 
 - [x] Spectrogram automatically computed from waveform.
-- [x] Choose custom spectrogram resolutions (fft, hop length, clipping dB value and PCEN).
-- [x] Select a new file directly from the interface.
+- [x] Choose custom spectrogram resolutions (fft, hop length, clipping of lowest dB values and PCEN).
 - [x] Exportation of contours to local `.json`.
-- [x] Move points once they are placed (with mouse wheel).
-- [x] Modification of previous annotation. Save & return later!
+- [x] Possibility to move points that are already placed.
+- [x] Modification of previous annotation files. Save & return later!
 
+## Known issues
+- laggy when high number of contours/points per contour
 
 ## Requirements
 
@@ -31,11 +32,11 @@ Install packages in your python environment with `$ pip install -r requirements.
 ## Usage
 
 ### Execution
-For classic use, download PyAVA folder, then open a terminal in the folder and run `$ python PyAVA.py -dir myWavefileFolder -out myOutputFolder` in terminal.  
+For classic usage, download PyAVA folder, then open a terminal in the folder and run `$ python PyAVA.py -dir myWavefileFolder -out myOutputFolder` in terminal.  
 
 Run `$python PyAVA.py --help` for details.  
 
-The annotations are saved in [JSON](http://www.json.org/) files. Each file contains a dictionnary with the categories annotated. For each category there is a list of points, each point is defined by a list of two elements : [time (in sec), frequency (in Hz)].
+The annotations are saved as [JSON](http://www.json.org/) files. Each file contains a dictionnary with the categories annotated. For each category there is a list of points, each point is defined by a list of two elements : [time (in sec), frequency (in Hz)].
 
 ### User actions
 - Use the toolbar to interact with the plot (same as with matplotlib.pyplot)
diff --git a/args.py b/args.py
index 4c34ade..b3eb7ff 100644
--- a/args.py
+++ b/args.py
@@ -53,6 +53,16 @@ def fetch_inputs():
 		"\nDefault value is '.' (current Directory).\n\n")
 	)
 
+	parser.add_argument(
+		'-f', '--audio_file',
+		type=str,
+		default="",
+		nargs='?', 
+		required=False,
+		help=("Audio file to annotate (wavefile).\n"
+		"If not give, will open a file explorer.\n\n")
+	)
+
 	parser.add_argument(
 		'-max', '--max_contours',
 		type=int,
@@ -111,6 +121,7 @@ def fetch_inputs():
 	explore = args.directory
 	contour = args.max_contours
 	resampl = args.resampling_rate
+	audio_file = args.audio_file
 	if args.modify == None:
 		modify_file, from_wav = False, False
 	else:
@@ -121,6 +132,9 @@ def fetch_inputs():
 	try:
 		assert (os.path.exists(outputs)), (
 			f"\nInputError: Could not find dir '{outputs}'.")
+		if len(audio_file)>0:
+			assert (os.path.exists(audio_file)), (
+				f"\nInputError: Could not find file '{audio_file}'.")
 		if isinstance(args.parameters, str):
 			assert (os.path.exists(args.parameters)), (
 				f"\nInputError: Could not find dir '{args.parameters}'.")
@@ -137,7 +151,7 @@ def fetch_inputs():
 		print(e)
 		sys.exit(1)
 
-	return (explore, contour, resampl, outputs, modify_file, from_wav, args.parameters)
+	return (audio_file, explore, contour, resampl, outputs, modify_file, from_wav, args.parameters)
 
 # if running `$python ARGS.py -h` for help.
 if __name__ == '__main__':
diff --git a/functions.py b/functions.py
index 202bbce..5edd864 100644
--- a/functions.py
+++ b/functions.py
@@ -3,7 +3,7 @@ import os
 import json
 import numpy as np
 
-from librosa import load, amplitude_to_db, stft, pcen
+from librosa import load, amplitude_to_db, stft, pcen, to_mono
 from scipy.signal import resample
 
 ##### FUNCTIONS #####
@@ -43,7 +43,7 @@ def save_dict(dictionary, folder, name, contours=True):
         with open(os.path.join(folder, name), "w") as f:
             json.dump(dictionary, f, indent=4) 
 
-def load_waveform(wavefile_name, sr_resample):
+def load_waveform(wavefile_name, sr_resample, channel="all"):
     """
     A function that loads any given wavefile 
     and it resamples it to a given sampling rate.
@@ -56,18 +56,35 @@ def load_waveform(wavefile_name, sr_resample):
         Path of the wavefile that will be loaded.
     sr_resample : int
         Resampling rate for the waveform.
+    channel : str, int or np.ndarray
+        "all" or int(s) corresponding to the channel to import.
+        The integer(s) should correspond to the index of the channel(s) to select.
+        Default is "all".
 
     Returns
     -------
     wavefile_dec : numpy array
         Loaded and resampled waveform
     """
-    wavefile, sr = load(wavefile_name, sr=None)
-    wavefile_dec = resample(wavefile, 
-        int(((len(wavefile)/sr)*sr_resample)))
-
-    return wavefile_dec
+    if channel == "all":
+        waveform, sr = load(wavefile_name, sr=None)
+
+    elif isinstance(channel, int):
+        waveforms, sr = load(wavefile_name, sr=None, mono=False)
+        waveform = np.copy(waveforms[channel])
+    
+    elif isinstance(channel, np.ndarray):
+        waveforms, sr = load(wavefile_name, sr=None, mono=False)
+        waveforms = np.copy(waveforms[channel])
+        waveform = to_mono(waveforms)
+    
+    else:
+        raise ValueError(f"Channel '{channel}' unknown. Should be 'all', an integer or an array of integers.")
 
+    waveform_dec = resample(waveform, 
+        int(((len(waveform)/sr)*sr_resample)))
+    return waveform_dec
+    
 def wave_to_spectrogram(waveform, SR, n_fft, w_size, clip, as_pcen=False, top_db=160):
     """
     A function that transforms any given waveform to a spectrogram.
diff --git a/interface.py b/interface.py
index 1921c60..24e7dcc 100644
--- a/interface.py
+++ b/interface.py
@@ -305,7 +305,7 @@ class App(object):
         self.figure_bboxes = []
         self.klicker = clicker(
             axis=self.axis,
-            names=["Line" + str(i+1) for i in range(self.NAME0, self.NAME1)],
+            names=["DefaultName" + str(i+1) for i in range(self.NAME0, self.NAME1)],
             bspline='quadratic', maxlines=99, legend_bbox=(2,0.5),
             coords=coords_to_modify)
 
@@ -847,7 +847,7 @@ class App(object):
                 new_name if key==old_name else key:value 
                 for key,value in self.klicker.coords.items()}
 
-            self.klicker._set_legend()
+            self.klicker._create_lines_and_points()
             self.klicker.current_line = index_item
             self.klicker.update_lines()
 
diff --git a/line_clicker/lite_line_clicker.py b/line_clicker/lite_line_clicker.py
index ca0271b..96c5a80 100644
--- a/line_clicker/lite_line_clicker.py
+++ b/line_clicker/lite_line_clicker.py
@@ -81,7 +81,6 @@ class clicker(object):
 		maxlines : int, optional.
 			Maximum number of labels in legend.
 			Default is 30.
-			Default is False.
 		names : str or list, optional.
 			Names for legend labels.
 			If string, names will be name+{line number}
-- 
GitLab