Skip to content
Snippets Groups Projects
Select Git revision
  • 9329c5fd78f3a6e7734b12d002341e8ffe959b83
  • main default protected
2 results

post_annotation.py

Blame
  • user avatar
    Loic-Lenof authored
    + points can be moved (mouse wheel)
    + added PCEN
    + added bspline
    + added modification of annotations after saving
    + access to saved data simplified
    - removed legend to access categories
    + added listbox to access categories
    - removed button to add categories
    + added shortcut to add categories one by one
    
    ~ replaced "classes.py" with "interface.py"
        |-> many changes (style + shortcuts + simplifications)
    ~ updated documentation and organisation
    ~ updated readme
    9329c5fd
    History
    post_annotation.py 4.62 KiB
    ##### IMPORTATIONS #####
    import os
    import json
    import numpy as np
    
    from matplotlib.patches import Rectangle
    import matplotlib.pyplot as plt
    
    from line_clicker.line_clicker import to_curve
    
    # Import external functions
    from functions import load_waveform, wave_to_spectrogram, save_dict
    
    ##### CLASS ######
    class Results(object):
    	"""
    	Load a recording and the contours saved from the interface, and plot them.

    	Parameters
    	----------
    	wavefile_name : str
    		Path to the audio file the contours were drawn on.
    	jsonfile_name : str
    		Path to the json file of contours. Each key is used as a label and
    		maps to a list of [x, y] points, drawn as (time in sec, frequency in Hz).
    	SR : int, optional.
    		Sampling rate used to load the waveform.
    		Default is 96 kHz.
    	NFFT : int, optional.
    		Size of the fft window, forwarded to `wave_to_spectrogram`.
    		Default is 4098.
    	HOP_LENGTH : int, optional.
    		Hop length between two fft windows, forwarded to `wave_to_spectrogram`.
    		Default is 156.
    	CLIPPING : int, optional.
    		Clipping value in dB, forwarded to `wave_to_spectrogram`.
    		Default is -80.
    	cmap : str, optional.
    		Color map used by matplotlib.pyplot to draw the image.
    		Default is viridis.

    	Attributes
    	----------
    	coords : dict
    		Contours loaded from `jsonfile_name` (see `load_contours_file`).
    	waveform
    		Audio returned by `load_waveform`.
    	spectrogram, duration
    		Image and duration (in sec) returned by `wave_to_spectrogram`.
    	pcen
    		Image returned by `wave_to_spectrogram(..., as_pcen=True)`.
    	"""

    	colors = [		# Colors for categories, cycled through by index.
    		'#1f77b4',	# The list can be extended or reduced.
    		'#ff7f0e',
    		'#2ca02c',
    		'#d62728',
    		'#9467bd',
    		'#8c564b',
    		'#e377c2',
    		'#7f7f7f',
    		'#bcbd22',
    		'#17becf'
    		]

    	def __init__(
    		self,
    		wavefile_name,
    		jsonfile_name,
    		SR=96_000,
    		NFFT=4098,
    		HOP_LENGTH=156,
    		CLIPPING=-80,
    		cmap='viridis'):

    		self.wavefile_name = wavefile_name
    		self.jsonfile_name = jsonfile_name
    		self.SR = SR
    		self.NFFT = NFFT
    		self.HOP_LENGTH = HOP_LENGTH
    		self.CLIPPING = CLIPPING
    		self.cmap = cmap

    		self.coords = self.load_contours_file()
    		self.waveform = load_waveform(self.wavefile_name, self.SR)
    		self.spectrogram, self.duration = wave_to_spectrogram(
    			self.waveform,
    			self.SR,
    			self.NFFT,
    			self.HOP_LENGTH,
    			self.CLIPPING)
    		# Same settings, PCEN variant; the returned duration is not used.
    		self.pcen, _ = wave_to_spectrogram(
    			self.waveform,
    			self.SR,
    			self.NFFT,
    			self.HOP_LENGTH,
    			self.CLIPPING,
    			as_pcen=True)

    	def load_contours_file(self):
    		"""
    		Import the contours saved from the interface.

    		Returns
    		-------
    		contours : dict
    			Data contained in the json file `self.jsonfile_name`.
    		"""
    		# JSON text is UTF-8; do not depend on the platform's default encoding.
    		with open(self.jsonfile_name, "r", encoding="utf-8") as f:
    			return json.load(f)

    	def _iter_contours(self):
    		"""
    		Yield (key, points, color) for every non-empty contour of `self.coords`.

    		`points` is the saved list converted once with `np.array`; column 0 is
    		plotted as x (time) and column 1 as y (frequency). The color index is
    		the key's position in `self.coords`, so skipping an empty contour does
    		not shift the colors of the following ones.
    		"""
    		for idx, (key, raw_points) in enumerate(self.coords.items()):
    			if len(raw_points) == 0:
    				continue  # nothing to draw, and column indexing / min() would fail
    			yield key, np.array(raw_points), self.colors[idx % len(self.colors)]

    	@staticmethod
    	def _bounding_box(points, tol=1/100):
    		"""
    		Return the (lower, upper) corners of the box enclosing `points`.

    		`points` is an array whose columns 0 and 1 are x and y. Each corner
    		coordinate v is moved by `tol * v` (lower: v - tol*v, upper: v + tol*v),
    		so the margin is relative to the value, and outwards for positive values.
    		"""
    		points = np.asarray(points)
    		x_min, y_min = points[:, 0].min(), points[:, 1].min()
    		x_max, y_max = points[:, 0].max(), points[:, 1].max()
    		lower = (x_min - tol * x_min, y_min - tol * y_min)
    		upper = (x_max + tol * x_max, y_max + tol * y_max)
    		return lower, upper

    	def display_image(self, img="spec"):
    		"""
    		Create a figure showing the spectrogram (or the PCEN image).

    		Parameters
    		----------
    		img : str, optional.
    			"pcen" to show the PCEN image, any other string for the
    			spectrogram. Default is "spec".

    		Returns
    		-------
    		fig, ax
    			Figure and axes created by `plt.subplots`.
    		"""
    		image = self.pcen if img == "pcen" else self.spectrogram
    		fig, ax = plt.subplots(figsize=(16, 9))
    		ax.imshow(
    			image[::-1],  # imshow draws row 0 on top: flip so row 0 sits at 0 Hz
    			cmap=self.cmap,
    			interpolation='nearest',
    			aspect='auto',
    			extent=(0, self.duration, 0, self.SR / 2))  # y spans 0 Hz to SR/2

    		ax.set_xlabel("Time (in sec)")
    		ax.set_ylabel("Frequencies (in Hz)")
    		ax.set_title(f"Spectrogram of {os.path.basename(self.wavefile_name)}")
    		return fig, ax

    	def display_contours(self, mode="curves", img="spec"):
    		"""
    		Show the results of annotations, after using the interface.

    		Parameters
    		----------
    		mode : str, optional.
    			Whether to plot curves or straight lines between each point:
    			"curves" or any other string for straight lines. Default is "curves".
    		img : str, optional.
    			Background image, passed to `display_image`. Default is "spec".

    		Returns
    		-------
    		None. Plots the contours fetched from the json file onto the image and
    		calls `plt.show()`. Contours without any point are skipped.
    		"""
    		fig, ax = self.display_image(img)

    		for key, points, color in self._iter_contours():
    			x, y = points[:, 0], points[:, 1]
    			if mode == "curves":
    				# line_clicker's `to_curve` is asked for a "quadratic" curve.
    				cx, cy = to_curve(x, y, kind="quadratic")
    				ax.plot(cx, cy, color=color)
    			else:
    				ax.plot(x, y, linestyle="-", color=color)
    			# Draw the saved points themselves as white squares over the line.
    			ax.plot(x, y, marker="s", mfc="white", linestyle="", color=color)

    		plt.show()

    	def display_as_BB(self, img="spec", tol=1/100):
    		"""
    		Show each contour as a labelled bounding box.

    		Parameters
    		----------
    		img : str, optional.
    			Background image, passed to `display_image`. Default is "spec".
    		tol : float, optional.
    			Relative margin added around each box (see `_bounding_box`).
    			Default is 1/100.

    		Returns
    		-------
    		None. Draws one box per contour, labelled with its key, and calls
    		`plt.show()`. Contours without any point are skipped.
    		"""
    		fig, ax = self.display_image(img)

    		for key, points, color in self._iter_contours():
    			(x0, y0), (x1, y1) = self._bounding_box(points, tol)
    			ax.add_patch(Rectangle(
    				(x0, y0),
    				x1 - x0,
    				y1 - y0,
    				facecolor='none',
    				edgecolor=color,
    				lw=1))

    			# Label at the top-left corner of the box.
    			ax.text(x0, y1, key,
    				color="white",
    				bbox=dict(
    					boxstyle='square, pad=0',
    					fc=color,
    					ec='none'))

    		plt.show()
    
    ##### EXAMPLE #####
    if __name__ == '__main__':
    	# Example: draw the contours saved for one of the bundled recordings.
    	example_wav = os.path.join(".", "audio_examples", "SCW1807_20200713_064554.wav")
    	example_json = os.path.join(".", "outputs", "SCW1807_20200713_064554-contours.json")

    	annot_data = Results(example_wav, example_json)
    	# Use annot_data.display_contours(img="pcen") to draw over the PCEN image.
    	annot_data.display_contours()