Skip to content
Snippets Groups Projects
Select Git revision
  • 3614416c7bb8fa3a0b0b537918a585beb4925367
  • main default protected
2 results

2-manual_verification.py

Blame
  • lite_line_clicker.py 21.00 KiB
    """
    Add default parameters for keys and mousebutton so it can be customized
    """
    
    ##### IMPORTATIONS ######
    import warnings
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.backend_bases import MouseButton
    import matplotlib.lines as mlines
    
    from scipy.interpolate import interp1d
    
    ##### FUNCTIONS #####
    def to_curve(x, y, kind="quadratic", precision=10):
    	"""
    		A function to compute a smooth curve passing through a set of points.

    		The points are sorted by x, interpolated with scipy's interp1d, and
    		the interpolant is sampled on ``precision * (len(x) - 1)`` evenly
    		spaced x values from min(x) to max(x).

    		...

    		Parameters
    		----------
    		x : list or numpy array
    			List of coordinates on x axis. Does not need to be sorted.
    		y : list or numpy array
    			List of coordinates on y axis, paired with x.
    		kind : string
    			The interpolation method to use. Can be any method in : 'linear',
    			'nearest', 'nearest-up', 'zero', 'slinear', 'quadratic', 'cubic',
    			'previous', or 'next'.
    		precision : int
    			Number of samples per segment between two consecutive points.

    		Returns
    		-------
    		xi : numpy array
    			List of the coordinates of the curve on x-axis.
    		yi : numpy array
    			List of the coordinates of the curve on y-axis.

    		Raises
    		------
    		ValueError
    			Raised by interp1d when there are too few points for ``kind``
    			(e.g. fewer than 3 for 'quadratic').
    	"""
    	# Accept plain lists as documented: fancy indexing and .min()/.max()
    	# below require numpy arrays.
    	x = np.asarray(x)
    	y = np.asarray(y)

    	# Sort both coordinates with the same permutation of x.
    	order = np.argsort(x)
    	x, y = x[order], y[order]

    	f = interp1d(x, y, kind=kind)

    	xi = np.linspace(x.min(), x.max(), precision * (len(x) - 1))
    	yi = f(xi)

    	return xi, yi
    
    ##### CLASSES #####
    class clicker(object):
    	"""
    		A class that is an "add-on" for an existing matplotlib figure axis.
    		Attached to an axis, it allows drawing lines from mouse clicks.
    		With the default bindings (DEFAULT_param): left click adds a point,
    		right click removes the closest point, and the middle button drags
    		the closest point.

    		Different categories can be selected using the generated legend:
    		by clicking on a legend entry or by using the up/down arrows.
    		More categories can be added with the 'shift+a' shortcut, and the
    		category in focus can be emptied with 'shift+r'.

    		...

    		Params
    		----------
    		axis : matplotlib object
    			Axis instance.
    		marker : dict, optional.
    			Keyword arguments for the point markers (matplotlib Line2D
    			properties). A "linestyle" key is ignored for points.
    			Default is None: white-filled circles (DEFAULT_marker).
    		bspline : string or boolean, optional.
    			If False, app will plot points linked by straight lines.
    			If a string, points are linked by a curve and the string is used
    			as the scipy interp1d ``kind`` (e.g. 'quadratic' or 'cubic').
    			Any other truthy value means 'quadratic'.
    		colors : list, optional.
    			List of colors (hexadecimal) to cycle over when plotting lines.
    			Default is DEFAULT_colors (matplotlib's default color cycle).
    		names : str or list, optional.
    			Names for legend labels.
    			If string, names will be name+{line number}.
    			If list, it gives the initial labels; categories added later are
    			named "DefaultNameX".
    			Default is "Line".
    		maxlines : int, optional.
    			Maximum number of labels in legend.
    			Default is 30.
    		wraplines : int, optional.
    			Maximum number of labels in a legend column before a new column
    			is created. Must be positive.
    			Default is 15.
    		legend_bbox : tuple, optional.
    			The position of the legend (x-axis, y-axis), passed to matplotlib
    			as ``bbox_to_anchor``.
    			Default is (1, 0.5).
    		pick_dist : int, optional.
    			Pick radius of the legend entries.
    			Default is 10.
    		n_names : int, optional.
    			Number of categories at launch when names is a string.
    			Default is 10.
    		coords : dict, optional.
    			Coordinates of each point for each category (line label as key,
    			list of [x, y] pairs as value).
    			Coordinates are expressed in (time, frequency)
    			or (pixel, pixel) if no extent given.
    			Default is None (empty categories). If given, initiates the figure
    			with these annotations; the given dict is then updated in place.
    		param : dict, optional.
    			Overrides for the key / mouse-button bindings of DEFAULT_param
    			(same keys). Default is None: DEFAULT_param is used as is.


    		Attributes
    		----------
    		activate_move : bool
    			True while a point is being dragged.
    		coords : dict
    			Points of each category: {label: [[x, y], ...]}. New and moved
    			points are kept in increasing x order.
    		current_line : int
    			Index of the category (line) currently selected.
    		figure : matplotlib object
    			Abbreviation for axis.figure.
    		legend : matplotlib object
    			Legend of the plot; its entries can be picked.
    		lines : list of matplotlib Line2D
    			One line per category linking its points (straight segments, or
    			an interpolated curve in bspline mode).
    		points : list of matplotlib Line2D
    			One marker-only line per category showing its points.
    		param : dict
    			Key / mouse-button bindings in use.
    		point_to_move : int or bool
    			Index of the point being dragged, False when none.
    		xlims, ylims : tuple
    			Current axis limits, used to skip drawing off-screen lines.
    		mouse_press_event, mouse_release_event, pick_event, key_press : int
    			Connection ids of the matplotlib callbacks.

    		Other attributes have self explanatory names.

    		Methods
    		-------
    		add_category(show):
    			Adds a category to the list in legend and puts focus on it.
    		add_point(x, y):
    			Adds a point to the category in focus at the given coordinates.
    		clear_category():
    			Clears all points from the category in focus.
    		distance_mouse(points, mouse_x, mouse_y):
    			Gets the distance between a mouse position and the points of the
    			currently selected category.
    		get_focus(event):
    			Puts focus on a category in legend. Allows the user to interact with it.
    		get_key_event(event, show=True):
    			Activates functions when specific keyboard keys are pressed.
    		get_mouse_press(event):
    			Activates functions when specific mouse buttons are pressed/released.
    		move_point(event):
    			Moves a selected point of the currently selected category to a new
    			position (event). Updates figure accordingly.
    		rm_point(x, y):
    			Removes the point of the category in focus closest to the given
    			coordinates.
    		switch_line(arrow):
    			Puts focus on previous/next category in legend based on arrow input.
    		update_lines():
    			Updates data in lines. Called when coords is changed.

    		References
    		----------
    		Heavily inspired from mpl_point_clicker :
    			mpl-point-clicker.readthedocs.io/
    	"""

    	DEFAULT_marker = {		# Markerstyle for categories.
    	"marker":"o",
    	"mfc":"white",
    	"linestyle":"-",
    	}

    	DEFAULT_colors = [	# Colors for categories. Will cycle through them.
    	'#1f77b4',			# List can be extended or reduced.
    	'#ff7f0e',
    	'#2ca02c',
    	'#d62728',
    	'#9467bd',
    	'#8c564b',
    	'#e377c2',
    	'#7f7f7f',
    	'#bcbd22',
    	'#17becf'
    	]

    	DEFAULT_param = {
    	"add_point" : MouseButton.LEFT,
    	"rm_point" : MouseButton.RIGHT,
    	"move_point" : MouseButton.MIDDLE,
    	"add_category" : "A", # shift+a
    	"clear_category" : "R", # shift+r
    	"focus_up" : "up",
    	"focus_down" : "down",
    	}

    	# Fallbacks used when constructor arguments have an unexpected type.
    	DEFAULT_name = "Line"
    	DEFAULT_maxlines = 30
    	DEFAULT_wraplines = 15

    	# Minimum number of points interp1d needs for spline kinds
    	# (a spline of order k needs k+1 points).
    	_SPLINE_MIN_POINTS = {"slinear": 2, "quadratic": 3, "cubic": 4}


    	def __init__(
    		self, 
    		axis,
    		marker=None,
    		bspline=False,
    		colors=None,
    		names="Line",
    		maxlines=30,
    		wraplines=15,
    		legend_bbox=(1,0.5),
    		pick_dist=10,
    		n_names=10,
    		coords=None,
    		param=None):

    		# Variable assignment
    		self.axis = axis
    		self.legend_bbox = legend_bbox
    		self.pick_dist = pick_dist
    		self.param = self.DEFAULT_param.copy()
    		if isinstance(param, dict):
    			self.param.update(param)
    		# Curves are only interpolated once a line has more points than this.
    		self.wait_before_interpolation = 2
    		self.bspline = bspline

    		if isinstance(marker, dict):
    			self.marker = marker
    		else:
    			self.marker = self.DEFAULT_marker.copy()

    		if isinstance(colors, list):
    			self.colors = colors
    		else:
    			self.colors = self.DEFAULT_colors.copy()

    		if isinstance(names, str):
    			self.names = names
    			coord_keys = [self.names + str(i+1) for i in range(n_names)]
    		elif isinstance(names, list):
    			coord_keys = names.copy()
    			self.names = "DefaultName"
    		else:
    			self.names = self.DEFAULT_name
    			coord_keys = [self.names + str(i+1) for i in range(n_names)]

    		# A non-empty user dict is used (and updated) directly.
    		if isinstance(coords, dict) and coords:
    			self.coords = coords
    		else:
    			self.coords = {key: [] for key in coord_keys}

    		if isinstance(maxlines, int):
    			self.maxlines = maxlines
    		else: 
    			self.maxlines = self.DEFAULT_maxlines

    		# wraplines is used as a divisor when laying out the legend.
    		if isinstance(wraplines, int) and wraplines > 0:
    			self.wraplines = wraplines
    		else:
    			self.wraplines = self.DEFAULT_wraplines

    		# Line creation
    		self._create_lines_and_points()

    		# Drawing lines
    		self.figure = self.axis.figure
    		self.current_line = 0
    		self._set_plot()

    		# State of the "drag a point" interaction
    		self.activate_move = False
    		self.point_to_move = False
    		self.motion_event = None

    		# Linking actions in matplotlib canvas
    		self.mouse_press_event = self.figure.canvas.mpl_connect(
    			'button_press_event', 
    			self.get_mouse_press)
    		self.mouse_release_event = self.figure.canvas.mpl_connect(
    			'button_release_event', 
    			self.get_mouse_press)

    		self.pick_event = self.figure.canvas.mpl_connect(
    			'pick_event', 
    			self.get_focus)
    		self.key_press = self.figure.canvas.mpl_connect(
    			'key_press_event', 
    			self.get_key_event)

    		# callbacks
    		self.xlims = self.axis.get_xlim()
    		self.ylims = self.axis.get_ylim()
    		self.axis.callbacks.connect('xlim_changed', self._on_xlims_change)
    		self.axis.callbacks.connect('ylim_changed', self._on_ylims_change)

    		# First refresh so that initial coords honour the bspline mode.
    		self.update_lines()

    	def _on_xlims_change(self, event_ax):
    		"""Store the new x limits and redraw the visible lines."""
    		self.xlims = event_ax.get_xlim()
    		self.update_lines()

    	def _on_ylims_change(self, event_ax):
    		"""Store the new y limits and redraw the visible lines."""
    		self.ylims = event_ax.get_ylim()
    		self.update_lines()

    	def _make_artists(self, label, idx, xdata=None, ydata=None):
    		"""
    			Create the (points, line) pair of Line2D for one category.

    			...

    			Parameters
    			----------
    			label : str
    				Legend label of the category.
    			idx : int
    				Index of the category, used to pick its color.
    			xdata, ydata : array-like, optional
    				Initial data. Default is None (empty).

    			Returns
    			-------
    			(point, line) : tuple of matplotlib Line2D
    				Marker-only artist (zorder 3) and linking line (zorder 2).
    		"""
    		if xdata is None:
    			xdata, ydata = [], []
    		color = self.colors[idx % len(self.colors)]
    		# Points reuse the marker style, without any connecting line.
    		marker_style = {k: v for k, v in self.marker.items() if k != "linestyle"}
    		point = mlines.Line2D(
    			xdata,
    			ydata,
    			label=label,
    			color=color,
    			zorder=3,
    			linestyle="",
    			**marker_style)
    		line = mlines.Line2D(
    			xdata,
    			ydata,
    			label=label,
    			zorder=2,
    			color=color)
    		return point, line

    	def _create_lines_and_points(self):
    		"""Build self.points and self.lines (one pair per category in coords)."""
    		self.lines = []
    		self.points = []
    		for idx, (legend_label, line_points) in enumerate(self.coords.items()):
    			if len(line_points) > 0:
    				arr = np.array(line_points)
    				point, line = self._make_artists(legend_label, idx, arr[:, 0], arr[:, 1])
    			else:
    				point, line = self._make_artists(legend_label, idx)
    			self.points.append(point)
    			self.lines.append(line)

    	def _set_plot(self):
    		"""
    			A method to create the matplotlib.pyplot legend and draw it.
    			Legend contains empty lines and can be clicked.
    			Lines and points not yet on the axis are added to it.

    			...

    			Returns
    			-------
    			None : updates axis, creates legend.
    		"""
    		ncol = (len(self.coords) // self.wraplines) + 1

    		# Make some space to include legend: the axis width shrinks by
    		# 1/19 per legend column.
    		scale = (19 - ncol) / 19
    		self.axis.set_position([0, 0, scale, 1])

    		# Add legend to plot (replaces any previous legend of the axis)
    		self.legend = self.axis.legend(
    			loc="center left", 
    			bbox_to_anchor=self.legend_bbox,
    			ncol=ncol,
    			title="Selection of lines",
    			handles=self.lines)

    		# Add lines and points to plot. Artists already attached to an axis
    		# are skipped so repeated calls (add_category) don't add duplicates.
    		for line, point in zip(self.lines, self.points):
    			for artist in (point, line):
    				if artist.axes is None:
    					self.axis.add_line(artist)

    		# Legend entries are pickable and dimmed...
    		for legend_l in self.legend.get_lines():
    			legend_l.set_picker(True)
    			legend_l.set_pickradius(self.pick_dist)
    			legend_l.set_alpha(0.2)

    		# ...except the one in focus.
    		self.legend.get_lines()[self.current_line].set_alpha(1)

    	def get_mouse_press(self, event):
    		"""
    			A Method that retrieves mouse interactions with matplotlib plot.

    			...

    			Parameters
    			----------
    			event : matplotlib object
    				Mouse event: uses its name, button, xdata and ydata.
    			
    			Returns
    			-------
    			None : is used to trigger methods add_point, rm_point and
    			the point dragging.
    		"""
    		pressed, x, y = event.button, event.xdata, event.ydata

    		# Releasing the "move" button always ends a drag, even when the
    		# mouse is outside the axis or another widget holds the canvas.
    		if event.name == 'button_release_event':
    			if pressed == self.param["move_point"]:
    				self._stop_move()
    			return

    		# Ignore presses while another widget holds the canvas lock.
    		if not self.figure.canvas.widgetlock.available(self):
    			return

    		# Only presses inside the data area carry float coordinates.
    		if (event.name != 'button_press_event' or
    			not isinstance(x, float) or
    			not isinstance(y, float)):
    			return

    		if pressed == self.param["add_point"]:
    			self.add_point(x, y)
    		elif pressed == self.param["rm_point"]:
    			self.rm_point(x, y)
    		elif pressed == self.param["move_point"]:
    			self._start_move(event)

    	def _start_move(self, event):
    		"""Select the point of the focused category closest to the mouse and start dragging it."""
    		# A previous drag that never received its release is ended first,
    		# so its motion callback is not leaked.
    		self._stop_move()

    		current = self.coords[list(self.coords)[self.current_line]]
    		if len(current) == 0:
    			return

    		distances = self.distance_mouse(current, event.xdata, event.ydata)
    		self.point_to_move = int(np.argmin(distances))
    		if np.min(distances) < 0.5:
    			self.activate_move = True
    			self.move_point(event)
    			self.motion_event = self.figure.canvas.mpl_connect(
    				'motion_notify_event', 
    				self.move_point)

    	def _stop_move(self):
    		"""End any drag in progress and disconnect the motion callback."""
    		if self.activate_move and self.motion_event is not None:
    			self.figure.canvas.mpl_disconnect(self.motion_event)
    		self.motion_event = None
    		self.activate_move = False
    		self.point_to_move = False

    	def update_lines(self):
    		"""
    			Updates data in lines. Called when coords is changed.

    			Categories with no point inside the current axis limits are not
    			drawn. In bspline mode, a category with enough points is drawn as
    			an interpolated curve (see to_curve), otherwise as straight segments.

    			...

    			Returns
    			-------
    			None : Updates figure lines
    		"""
    		# A string bspline is the interp1d kind; any other truthy value
    		# falls back to "quadratic".
    		kind = self.bspline if isinstance(self.bspline, str) else "quadratic"
    		# Curves need more than wait_before_interpolation points, and at
    		# least as many as the spline kind requires.
    		min_points = max(
    			self.wait_before_interpolation + 1,
    			self._SPLINE_MIN_POINTS.get(kind, 1))

    		x_low, x_high = min(self.xlims), max(self.xlims)
    		y_low, y_high = min(self.ylims), max(self.ylims)

    		for i_line, line_points in enumerate(self.coords.values()):
    			line_coords = np.array(line_points)
    			point_artist = self.points[i_line]
    			line_artist = self.lines[i_line]

    			if line_coords.shape[0] == 0:
    				point_artist.set_data([], [])
    				line_artist.set_data([], [])
    				continue

    			xs, ys = line_coords[:, 0], line_coords[:, 1]
    			# MAIN OPTIMIZATION: skip lines outside the visible frame
    			in_xframe = np.any(xs >= x_low) and np.any(xs <= x_high)
    			in_yframe = np.any(ys >= y_low) and np.any(ys <= y_high)
    			if not (in_xframe and in_yframe):
    				point_artist.set_data([], [])
    				line_artist.set_data([], [])
    				continue

    			point_artist.set_data(xs, ys)
    			if self.bspline and len(xs) >= min_points:
    				curvex, curvey = to_curve(xs, ys, kind=kind)
    				line_artist.set_data(curvex, curvey)
    			else:
    				line_artist.set_data(xs, ys)

    		self.figure.canvas.draw_idle()

    	def distance_mouse(self, points, mouse_x, mouse_y):
    		"""
    		Gets the distance between a mouse position
    		and the currently selected category (points).

    		Both axes are rescaled by (figure size in inches / upper axis bound)
    		so that x and y distances are roughly comparable.

    		...
    		Parameters
    		----------
    		points : list or numpy array
    			A list of [x, y] coordinates, shape (n, 2). May be empty.
    		mouse_x : int or float
    			Coordinate of the mouse on x-axis.
    		mouse_y : int or float
    			Coordinate of the mouse on y-axis.
    		Returns
    		-------
    		distances : numpy array
    			Distances of each point to the mouse coordinates, shape (n,).
    		"""
    		size_x, size_y = self.figure.get_size_inches()
    		max_x, max_y = self.axis.get_xbound()[1], self.axis.get_ybound()[1]

    		scale = np.array([size_x / max_x, size_y / max_y])
    		# reshape keeps an empty input as a (0, 2) array.
    		points = np.asarray(points, dtype=float).reshape(-1, 2)
    		mouse = np.array([mouse_x, mouse_y], dtype=float)

    		return np.linalg.norm((points - mouse) * scale, axis=1)

    	def add_point(self, x, y):
    		"""
    			A method to add a point on plot at given coordinates.
    			The point is inserted before the first point with a larger x,
    			and rejected (with a UserWarning) if a point already has this x.

    			...

    			Parameters
    			----------
    			x : float
    				x-axis coordinate.
    			y : float
    				y-axis coordinate.

    			Returns
    			-------
    			None : updates coords, lines, figure.
    		"""
    		current = self.coords[list(self.coords)[self.current_line]]

    		if len(current) == 0:
    			current.append([x, y])
    		else:
    			xs = np.array(current)[:, 0]
    			if x in xs:
    				warnings.warn("Cannot place two points at the same timestamp!", UserWarning, stacklevel=2)
    			else:
    				# where should it be inserted ?
    				after = np.flatnonzero(xs > x)
    				if len(after) != 0:
    					current.insert(int(after[0]), [x, y])
    				else:
    					current.append([x, y])

    		self.update_lines()

    	def rm_point(self, x, y):
    		"""
    		A method to remove the closest point in plot to given coordinates.
    		Nothing is removed if the closest point is too far (distance >= 1,
    		as measured by distance_mouse).

    		...

    		Parameters
    		----------
    		x : float
    			x-axis coordinate.
    		y : float
    			y-axis coordinate.

    		Returns
    		-------
    		None : updates coords, lines, figure.
    		"""
    		current = self.coords[list(self.coords)[self.current_line]]
    		distances = self.distance_mouse(current, x, y)

    		if len(distances) > 0 and np.min(distances) < 1:
    			# remove closest point of selected category
    			del current[int(np.argmin(distances))]

    		self.update_lines()

    	def move_point(self, event):
    		"""
    			Moves the selected point (point_to_move) of the currently selected
    			category to a new position (event). The category is then re-sorted
    			by x so that later insertions stay ordered. Updates figure accordingly.

    			...

    			Parameters
    			----------
    			event : matplotlib object
    				Matplotlib event containing information on positions.

    			Returns
    			-------
    			None : it updates the coordinates of a point in a category 
    			and re-draws the figure.
    		"""
    		if (not self.activate_move or
    			event.xdata is None or
    			event.ydata is None):
    			return

    		current = self.coords[list(self.coords)[self.current_line]]

    		# does this point already exist ?
    		if event.xdata in np.array(current)[:, 0]:
    			warnings.warn("Cannot place two points at the same timestamp!", UserWarning, stacklevel=2)
    			return

    		# update coords
    		current[self.point_to_move] = [event.xdata, event.ydata]

    		# Keep the category ordered by x and follow the dragged point
    		# to its new index.
    		order = sorted(range(len(current)), key=lambda i: current[i][0])
    		current[:] = [current[i] for i in order]
    		self.point_to_move = order.index(self.point_to_move)

    		self.update_lines()

    	def get_key_event(self, event, show=True):
    		"""
    		A method that retrieves key interactions with matplotlib plot.

    		...

    		Parameters
    		----------
    		event : matplotlib object
    			Key event: uses its key attribute.
    		show : boolean
    			Parameter passed to add_category()
    		
    		Returns
    		-------
    		None : is used to trigger other functions.
    		"""
    		key = event.key
    		if ((key == self.param["add_category"]) and 
    			(len(self.coords) < self.maxlines)):
    			self.add_category(show)

    		elif key == self.param["clear_category"]:
    			self.clear_category()

    		elif ((key == self.param["focus_up"]) and
    			(self.current_line != 0)):
    			self.switch_line(-1)

    		elif ((key == self.param["focus_down"]) and
    			(self.current_line != len(self.coords)-1)):
    			self.switch_line(1)

    	def add_category(self, show):
    		"""
    		A method to add a line (therefore a new category) to the plot.
    		Also change current focus to be on the newly created category.

    		...

    		Parameters
    		----------
    		show : boolean
    			If True, request a redraw of the canvas at the end.

    		Returns
    		-------
    		None : updates coords, points, lines, legend, figure 
    		and current_line.
    		"""
    		# make new name, skipping labels that already exist (e.g. from a
    		# user-given coords dict) so coords and artists stay in sync
    		index = len(self.coords) + 1
    		new_name = self.names + str(index)
    		while new_name in self.coords:
    			index += 1
    			new_name = self.names + str(index)
    		self.coords[new_name] = []

    		# add to points and lines
    		point, line = self._make_artists(new_name, len(self.coords) - 1)
    		self.points.append(point)
    		self.lines.append(line)

    		self._set_plot()
    		self.update_lines()

    		# focus auto on new category
    		self.current_line = len(self.coords)-1

    		for legend_line in self.legend.get_lines():
    			legend_line.set_alpha(0.2)
    		self.legend.get_lines()[self.current_line].set_alpha(1)
    		if show:
    			self.figure.canvas.draw_idle()

    	def clear_category(self):
    		"""
    		A method to remove all points of the category in focus.
    		The category itself stays in the legend.

    		...

    		Returns
    		-------
    		None : updates coords, points, lines, figure.
    		"""
    		# we only need to remove data from coords
    		self.coords[list(self.coords)[self.current_line]] = []
    		self.points[self.current_line].set_data([], [])
    		self.lines[self.current_line].set_data([], [])

    		# update plot
    		self.update_lines()

    	def switch_line(self, arrow):
    		"""
    		Puts focus on previous/next category in legend based on arrow input.

    		...

    		Parameters
    		----------
    		arrow : int
    			Can be 1 or -1. Change the index of the current category in focus.
    		
    		Returns
    		-------
    		None : updates current_line, legend and figure attributes.
    		"""
    		# new focus
    		self.current_line += arrow

    		# adapt alpha
    		for legend_line in self.legend.get_lines():
    			legend_line.set_alpha(0.2)
    		self.legend.get_lines()[self.current_line].set_alpha(1)
    		self.figure.canvas.draw_idle()	

    	def get_focus(self, event):
    		"""
    		A method to highlight a given entry in legend and put focus on it.
    		Points can be added/removed only to the category in focus.
    		Pick events on artists that are not entries of the current legend
    		are ignored.
    		
    		...

    		Parameters
    		----------
    		event : matplotlib object
    			Is used to access event.artist attribute.
    		
    		Returns
    		-------
    		None : updates figure, current_line and legend attributes.
    		"""
    		legend_lines = list(self.legend.get_lines())
    		if event.artist not in legend_lines:
    			return

    		# dim every legend entry, then highlight the picked one
    		for legend_line in legend_lines:
    			legend_line.set_alpha(0.2)
    		event.artist.set_alpha(1.0)

    		# new focus
    		self.current_line = legend_lines.index(event.artist)
    		self.figure.canvas.draw_idle()
    
    
    ##### MAIN #####
    if __name__ == '__main__':
    	# Dummy example: a 50x50 checkerboard of two close gray levels, with a
    	# 0-valued pixel and a 1-valued pixel in opposite corners.
    	tile = [[0.45, 0.55], [0.55, 0.45]]
    	img = np.tile(tile, (25, 25))
    	img[0, 0] = 0
    	img[-1, -1] = 1

    	fig, ax = plt.subplots(figsize=(16, 9))
    	ax.imshow(img, cmap="gray")

    	# Keep a reference to the clicker while the window is open.
    	base = clicker(axis=ax, bspline="quadratic")
    	plt.show(block=True)