In this article we will showcase a useful piece of code that I developed during one of my side projects. The project required an interactive scatter plot that the user could interact with in many ways, including mass-selecting and highlighting individual scatter points on the graph, and performing actions on the highlighted points (e.g. delete, copy, move).
This article mainly serves to share the core parts of this code in the form of a class, which our readers may copy and use in their own programs. Important points of interest in the code will also be explained.
The Highlighter Class
We will start by creating the Highlighter class in a separate file, Highlighter.py (the GUI application later imports it from there), and discuss each method one by one.
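import pandas as pd
import numpy as np
from matplotlib.widgets import RectangleSelector
import keyboard  # third-party package, used to detect whether Ctrl is held

class Highlighter(object):
    def __init__(self, canvas, ax, x, y):
        self.canvas = canvas
        self.ax = ax
        self.x, self.y = x, y
        # One boolean per point: True means the point is currently selected
        self.mask = np.zeros(x.shape, dtype=bool)
        # Empty yellow scatter drawn above the data; holds the highlight markers
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        # Drag tool that calls self.select when a drag finishes.
        # Alt (instead of the default Ctrl) draws the rectangle from its center.
        self.selector = RectangleSelector(self.ax, self.select, useblit=True, state_modifier_keys={"center": "alt"})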
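The core part of the Highlighter class is the RectangleSelector, a Matplotlib widget that provides a rectangular drag tool. It only takes care of the drag tool's user interface; the actual selection, the highlighting, and the rest of the features are implemented by us. The state_modifier_keys argument moves the "draw from the center" modifier from its default key, Ctrl, to Alt, which keeps Ctrl free for multi-stage selection, while useblit=True lets Matplotlib redraw only the rectangle during a drag, keeping the tool responsive.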
The other important part of __init__ is the self.mask attribute, a NumPy boolean array with one entry per point in our scatter plot. It tracks whether each point is currently selected: a True value at index 2, for example, means that the third scatter point in our plot is highlighted. Any number of points may be selected at the same time. Finally, self._highlight is an initially empty yellow scatter plot, drawn with zorder=10 so that it sits on top of the regular points; the highlight markers are placed there.
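To make the mask idea concrete, here is a small standalone NumPy sketch showing how a boolean mask picks out the selected points. The coordinates are made-up example values, not data from the application:

```python
import numpy as np

# Five made-up scatter points
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([9.0, 8.0, 7.0, 6.0, 5.0])

# One boolean per point; nothing is selected yet
mask = np.zeros(x.shape, dtype=bool)

# Select the third point (index 2) and the fifth point (index 4)
mask[2] = True
mask[4] = True

# Boolean indexing keeps only the selected points, in their original order
selected_xy = np.column_stack([x[mask], y[mask]])
print(selected_xy)
```

This is the same kind of filtering that select performs with self.x[self.mask] and self.y[self.mask].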
    def update(self, x, y):
        # Store the new data and rebuild the mask to match its length
        self.x, self.y = x, y
        self.mask = np.zeros(x.shape, dtype=bool)
Next is the update method, which is called whenever the data behind the scatter plot changes (for example, when points are added or removed). It stores the new x and y values in the Highlighter and rebuilds the mask so that it has one entry per point, which also resets the selection. Note that this does not affect the actual scatter plot, whose code lives elsewhere in the main application.
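The mask has to be rebuilt because a boolean mask must have exactly one entry per point. A minimal NumPy sketch, again with made-up values, of what happens if an old mask is reused after a point has been added:

```python
import numpy as np

x_old = np.array([0.0, 1.0, 2.0])
mask = np.array([False, True, False])  # second point selected

# A new point is added: the data now has four entries
x_new = np.append(x_old, 3.0)

# The old three-element mask no longer fits the four-element data
try:
    x_new[mask]
    mismatch_detected = False
except IndexError:
    mismatch_detected = True

# Rebuilding the mask, as update() does, gives one flag per current point
mask = np.zeros(x_new.shape, dtype=bool)
print(mismatch_detected, mask)
```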
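    def select(self, event1, event2):
        prevOffsets = []
        prevMask = None
        # With Ctrl held, remember the current highlights so they are kept
        # (clear_highlights() creates a new mask array, so prevMask keeps the old one)
        if keyboard.is_pressed('ctrl'):
            prevOffsets = self._highlight.get_offsets()
            prevMask = self.mask
        # Start a fresh selection from the points inside the dragged rectangle
        self.clear_highlights()
        self.mask |= self.inside(event1, event2)
        xy = np.column_stack([self.x[self.mask], self.y[self.mask]])
        # Merge in the previous selection (Ctrl + drag)
        if len(prevOffsets) > 0:
            xy = np.concatenate((xy, prevOffsets))
            self.mask |= prevMask
        # Move the yellow highlight markers onto the selected points
        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()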
The most important method is select, which the RectangleSelector calls every time the user completes a drag with the tool. It is called because we passed self.select to the RectangleSelector constructor as its second parameter, the onselect callback; Matplotlib invokes it with two mouse events, one for the button press and one for the release.
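The contract between RectangleSelector and its callback can be sketched without a figure. Matplotlib calls the onselect function with the press event (often named eclick) and the release event (erelease), whose xdata and ydata attributes hold the corners in data coordinates. In this sketch the events are faked with types.SimpleNamespace, and on_select is a hypothetical callback, not part of the Highlighter class:

```python
from types import SimpleNamespace


def on_select(eclick, erelease):
    """Hypothetical onselect callback: returns the dragged rectangle as
    (xmin, xmax, ymin, ymax) in data coordinates."""
    # The drag can go in any direction, so order the corners first
    xmin, xmax = sorted([eclick.xdata, erelease.xdata])
    ymin, ymax = sorted([eclick.ydata, erelease.ydata])
    return xmin, xmax, ymin, ymax


# Stand-ins for the press and release events; real ones come from Matplotlib
press = SimpleNamespace(xdata=8.0, ydata=1.0)
release = SimpleNamespace(xdata=2.0, ydata=5.0)  # dragged up and to the left
bounds = on_select(press, release)
print(bounds)
```

In a real figure the callback would be registered with RectangleSelector(ax, on_select), just as the Highlighter registers self.select. Matplotlib ignores the callback's return value; it is returned here only so the result can be inspected.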
The select method does the following:
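- If the Ctrl key is held down (detected with the third-party keyboard package, which requires root privileges on Linux), saves the currently highlighted points and the current mask, so that a selection can be built up over several drags.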
- Clears the existing highlights and calls the inside method, which returns a boolean array marking which points lie inside the dragged rectangle; the result is stored in the mask attribute. inside works from the two events passed to select: the first gives the location where the drag started (where the mouse button was pressed) and the second where it ended (where the button was released).
    def inside(self, event1, event2):
        """Returns a boolean mask of the points inside the rectangle defined by
        event1 and event2."""
        # Sort the corners so the drag direction does not matter
        x0, x1 = sorted([event1.xdata, event2.xdata])
        y0, y1 = sorted([event1.ydata, event2.ydata])
        # Strict comparisons: points exactly on the border are not selected
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask
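- Filters the x and y values using the mask and stacks them into the xy array, which holds the coordinates of the points to highlight. If Ctrl was held, the previously highlighted coordinates are appended and the previous mask is merged back in. Finally, the highlight scatter's own set_offsets method (self._highlight.set_offsets) moves the yellow markers onto these coordinates, and the canvas is redrawn. The class additionally defines a set_offsets helper of its own, shown below, which recreates the highlight scatter before placing the markers: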
    def set_offsets(self, xy):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()
Highlighting is done by drawing yellow scatter points on top of the points that are selected. Alternatively, you could modify the code to change the color of the selected points themselves to yellow.
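For that alternative, one option is to give each point its own face color. The sketch below assumes all points live in a single scatter collection (unlike the demo application, which creates one scatter artist per point) and only builds the RGBA color array from the mask; the face_colors helper and the two color constants are hypothetical:

```python
import numpy as np

# RGBA colors (values from 0 to 1) for normal and highlighted points
BLUE = np.array([0.0, 0.0, 1.0, 1.0])
YELLOW = np.array([1.0, 1.0, 0.0, 1.0])


def face_colors(mask):
    """One RGBA row per point: yellow where mask is True, blue elsewhere."""
    mask = np.asarray(mask, dtype=bool)
    # mask[:, None] has shape (N, 1), so it broadcasts against the (4,) colors
    return np.where(mask[:, None], YELLOW, BLUE)


mask = np.array([False, True, False, True])
colors = face_colors(mask)
print(colors)
```

Applying the array would then be a matter of calling set_facecolor(colors) on that single collection and redrawing the canvas, which avoids managing a second scatter plot just for the highlights.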
    def clear_highlights(self):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        self.mask = np.zeros(self.x.shape, dtype=bool)
        self.canvas.draw_idle()
Finally, the clear_highlights method, used by select and by the application, removes the yellow highlight scatter, replaces it with an empty one, and resets the mask so that no points are selected.
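The Ctrl-based merging in select can also be expressed purely in terms of the mask. The NumPy sketch below is an alternative formulation, not the class's own code: points_inside and apply_selection are hypothetical helpers, and the Ctrl state is passed in as a plain additive flag so the logic does not depend on how the key is detected:

```python
import numpy as np


def points_inside(x, y, p0, p1):
    """Boolean mask of the points strictly inside the rectangle with corners p0, p1."""
    x0, x1 = sorted([p0[0], p1[0]])
    y0, y1 = sorted([p0[1], p1[1]])
    return (x > x0) & (x < x1) & (y > y0) & (y < y1)


def apply_selection(mask, x, y, p0, p1, additive):
    """Return the new selection mask.

    additive=True mimics holding Ctrl: the new rectangle is OR-ed into the
    existing selection. Otherwise the old selection is replaced."""
    new = points_inside(x, y, p0, p1)
    return (mask | new) if additive else new


x = np.array([1.0, 2.0, 6.0, 7.0])
y = np.array([1.0, 2.0, 6.0, 7.0])
mask = np.zeros(x.shape, dtype=bool)

mask = apply_selection(mask, x, y, (0, 0), (3, 3), additive=False)     # first drag
mask = apply_selection(mask, x, y, (5, 5), (6.5, 6.5), additive=True)  # Ctrl + drag

# The highlight coordinates come straight from the combined mask
xy = np.column_stack([x[mask], y[mask]])
print(mask, xy)
```

Deriving the highlight coordinates from the combined mask, instead of concatenating the previous offsets with the new ones, means a point caught by two overlapping drags is drawn only once.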
Highlighter Class – Complete Code
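Here is the complete combined code for the Highlighter class. It also includes check_inside, a variant of inside that takes two plain (x, y) coordinate pairs instead of mouse events, so a region can be selected from code rather than with the mouse.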
import pandas as pd
import numpy as np
from matplotlib.widgets import RectangleSelector
import keyboard  # third-party package, used to detect whether Ctrl is held

class Highlighter(object):
    def __init__(self, canvas, ax, x, y):
        self.canvas = canvas
        self.ax = ax
        self.x, self.y = x, y
        # One boolean per point: True means the point is currently selected
        self.mask = np.zeros(x.shape, dtype=bool)
        # Empty yellow scatter drawn above the data; holds the highlight markers
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        # Alt (instead of the default Ctrl) draws the rectangle from its center
        self.selector = RectangleSelector(self.ax, self.select, useblit=True, state_modifier_keys={"center": "alt"})

    def update(self, x, y):
        # Store the new data and rebuild the mask to match its length
        self.x, self.y = x, y
        self.mask = np.zeros(x.shape, dtype=bool)

    def select(self, event1, event2):
        prevOffsets = []
        prevMask = None
        # With Ctrl held, remember the current highlights so they are kept
        if keyboard.is_pressed('ctrl'):
            prevOffsets = self._highlight.get_offsets()
            prevMask = self.mask
        self.clear_highlights()
        self.mask |= self.inside(event1, event2)
        xy = np.column_stack([self.x[self.mask], self.y[self.mask]])
        if len(prevOffsets) > 0:
            xy = np.concatenate((xy, prevOffsets))
            self.mask |= prevMask
        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()

    def clear_highlights(self):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        self.mask = np.zeros(self.x.shape, dtype=bool)
        self.canvas.draw_idle()

    def set_offsets(self, xy):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()

    def check_inside(self, point1, point2):
        """Returns a boolean mask of the points inside the rectangle defined by
        the (x, y) coordinate pairs point1 and point2."""
        x0, x1 = sorted([point1[0], point2[0]])
        y0, y1 = sorted([point1[1], point2[1]])
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask

    def inside(self, event1, event2):
        """Returns a boolean mask of the points inside the rectangle defined by
        event1 and event2."""
        x0, x1 = sorted([event1.xdata, event2.xdata])
        y0, y1 = sorted([event1.ydata, event2.ydata])
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask
The GUI application
Here is the code for our driver application, which wraps the Matplotlib graph in a Tkinter GUI. Using Tkinter is not required: adapting the code to another setup only takes a few small changes, since the Highlighter itself only needs a canvas, an axes, and the x and y data. That said, any graph application that needs such an advanced feature will most likely use a GUI library as well, since Matplotlib's own GUI capabilities are limited.
The only method of note here is delete, which removes the selected points when the Delete key is pressed. It uses the highlighter's mask both to filter the DataFrame and to remove the matching scatter artists, so it can serve as a template for other actions such as copy, paste, or move.
import tkinter as tk
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from Highlighter import Highlighter  # the class above, saved as Highlighter.py
import pandas as pd
import random

class MatplotlibApp:
    def __init__(self, master):
        self.master = master
        self.master.title("Matplotlib App")
        self.data = pd.DataFrame(columns=["x", "y"])
        self.points = []  # one scatter artist per data point
        self.create_widgets()

    def create_widgets(self):
        self.fig = Figure(figsize=(5, 4), dpi=100)
        self.ax = self.fig.add_subplot(111)
        self.ax.set_xlim(-1, 11)
        self.ax.set_ylim(-1, 11)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.master)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        # Pressing the Delete key removes the selected points
        self.canvas.get_tk_widget().bind("<Delete>", self.delete)
        self.canvas.draw()
        tk.Button(self.master, text="Add Random Point", command=self.add).pack(padx=20, pady=20)
        self.highlighter = Highlighter(self.canvas, self.ax, self.data["x"], self.data["y"])

    def add(self):
        x, y = random.randint(0, 10), random.randint(0, 10)
        df = pd.DataFrame([[x, y]], columns=["x", "y"])
        self.data = pd.concat([self.data, df], ignore_index=True)
        self.points.append(self.ax.scatter(x, y, color="blue"))
        self.highlighter.update(self.data["x"], self.data["y"])
        # update() resets the mask, so also remove the now-stale highlight markers
        self.highlighter.clear_highlights()
        self.canvas.draw()

    def delete(self, event):
        mask = self.highlighter.mask  # True for every selected point
        # Drop the selected rows and renumber the remaining ones
        self.data = self.data[~mask].reset_index(drop=True)
        # Remove the matching scatter artists from the plot
        for artist, selected in zip(self.points, mask):
            if selected:
                artist.remove()
        self.points = [artist for artist, selected in zip(self.points, mask) if not selected]
        # Sync the highlighter with the remaining data and clear the highlights
        self.highlighter.update(self.data["x"], self.data["y"])
        self.highlighter.clear_highlights()
        self.canvas.draw()

def main():
    root = tk.Tk()
    app = MatplotlibApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()
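Following the pattern of delete, other actions only need the highlighter's mask. The sketch below is hypothetical and not part of the application: to stay self-contained it keeps the points in an (N, 2) NumPy array instead of the app's DataFrame, and shows delete, copy, and move as pure functions:

```python
import numpy as np


def delete_selected(points, mask):
    """Keep only the unselected rows (the same idea as the app's delete)."""
    return points[~mask]


def copy_selected(points, mask):
    """Return an independent copy of the selected rows, e.g. for a clipboard."""
    return points[mask].copy()


def move_selected(points, mask, dx, dy):
    """Return a new array with the selected rows shifted by (dx, dy)."""
    moved = points.copy()
    moved[mask] += (dx, dy)
    return moved


points = np.array([[1.0, 1.0], [2.0, 5.0], [4.0, 4.0]])
mask = np.array([True, False, True])

remaining = delete_selected(points, mask)
clipboard = copy_selected(points, mask)
shifted = move_selected(points, mask, dx=1.0, dy=-1.0)
print(remaining, clipboard, shifted, sep="\n")
```

With the application's DataFrame, the same mask would be used through .loc, for example self.data.loc[mask, "x"] += dx for a move. After any such action, the scatter artists also need updating, followed by highlighter.update(...) and clear_highlights(), as delete does.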
This marks the end of the Interactive Scatter Plot with Matplotlib Rectangle Selector Tutorial. Any suggestions or contributions for CodersLegacy are more than welcome. Questions regarding the tutorial content may be asked in the comments section below.