Interactive Scatter Plot Highlighting and Deletion with Matplotlib

In this article we will showcase a useful piece of code that was developed during a side project of mine. The project required an interactive scatter plot that the user could interact with in many ways. These included mass-selecting and highlighting scatter points on the graph, and performing actions on the highlighted points (e.g. delete, copy, move).

This article mainly serves to share the core parts of this code in the form of a class, which our readers may copy and use in their own programs. The most important parts of the code are also explained along the way.


The Highlighter Class

We will start by creating the Highlighter class in a separate file, Highlighter.py, which the driver application imports later on. Each method is then discussed in turn.

Python
import numpy as np
from matplotlib.widgets import RectangleSelector
import keyboard  # third-party package (pip install keyboard), used to detect CTRL

class Highlighter(object):
    def __init__(self, canvas, ax, x, y):
        self.canvas = canvas
        self.ax = ax
        self.x, self.y = x, y
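        # One boolean flag per point: True means the point is currently selected
        self.mask = np.zeros(x.shape, dtype=bool)

        # Empty yellow scatter drawn on top of the data (zorder=10) to show highlights
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        # Drag tool; calls self.select(press_event, release_event) after each drag
        self.selector = RectangleSelector(self.ax, self.select, useblit=True, state_modifier_keys={"center": "alt"})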

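The core of the Highlighter class is the RectangleSelector, a Matplotlib widget that provides a click-and-drag selection tool. However, it only supplies the UI for the drag tool; the actual selection, highlighting, and all other features are implemented by us. The state_modifier_keys argument moves the "draw from the center" modifier from its default Ctrl key to Alt, which keeps Ctrl free for the multi-selection feature covered below.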
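To see how little the widget does on its own, here is a minimal, GUI-free sketch. It assumes a headless Agg canvas in place of the Tk canvas and a hypothetical callback named on_select. RectangleSelector normally calls the callback itself when the mouse button is released; here the call is made by hand with stand-in events (types.SimpleNamespace objects carrying only xdata and ydata, the two attributes the Highlighter actually reads).

```python
from types import SimpleNamespace

from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
from matplotlib.widgets import RectangleSelector

fig = Figure()
FigureCanvasAgg(fig)  # headless stand-in for FigureCanvasTkAgg
ax = fig.add_subplot()
ax.scatter([1, 4, 7], [2, 5, 8])

drags = []  # rectangle corners reported to the callback, in data coordinates


def on_select(eclick, erelease):
    """Hypothetical callback: it only learns where the drag started and ended.
    Deciding which points count as selected is left entirely to our code."""
    drags.append(((eclick.xdata, eclick.ydata), (erelease.xdata, erelease.ydata)))


# Same construction as in Highlighter.__init__; Alt replaces the default
# Ctrl key for the "draw from center" modifier.
selector = RectangleSelector(ax, on_select, useblit=True,
                             state_modifier_keys={"center": "alt"})

# Stand-in for a drag from (0.5, 1.0) to (5.0, 6.0)
on_select(SimpleNamespace(xdata=0.5, ydata=1.0), SimpleNamespace(xdata=5.0, ydata=6.0))
print(drags)  # [((0.5, 1.0), (5.0, 6.0))]
```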


The other important part of the constructor is the self.mask attribute, a NumPy boolean array with one entry per point in the scatter plot. It tracks whether each point is currently selected: a value of True at index 2 means that the third scatter point in our plot is currently highlighted. Multiple points may be selected at the same time.
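A short NumPy sketch of the same idea, using made-up coordinates: the mask has the same length as the data, and indexing with it pulls out exactly the selected points. This is the indexing that select relies on when it builds the highlight coordinates.

```python
import numpy as np

x = np.array([1.0, 4.0, 7.0, 9.0])
y = np.array([2.0, 5.0, 8.0, 1.0])

# One flag per scatter point; everything starts unselected.
mask = np.zeros(x.shape, dtype=bool)

# Select the first point (index 0) and the third point (index 2).
mask[0] = True
mask[2] = True

print(mask)                  # [ True False  True False]
print(np.flatnonzero(mask))  # [0 2]
print(x[mask], y[mask])      # [1. 7.] [2. 8.]
```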

Python
    def update(self, x, y):
        self.x, self.y = x, y
        self.mask = np.zeros(x.shape, dtype=bool)

Next is the update method, which is called whenever the scatter plot receives new values (or has some removed). It stores the new x and y values in the Highlighter and replaces the mask with an all-False array sized to the new number of points, which also resets the selection state. Note that this does not affect the actual scatter plot, whose code lives elsewhere in the main application.
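A small sketch with made-up values shows why update rebuilds the mask instead of keeping the old one: NumPy refuses to index an array with a boolean mask of a different length (pandas also rejects it).

```python
import numpy as np

x = np.array([1.0, 4.0, 7.0])
mask = np.array([False, True, False])  # the second point was selected

# Elsewhere in the app a point is added, so the data now has 4 entries...
x = np.append(x, 9.0)

# ...and the old 3-element mask no longer lines up with it.
stale = False
try:
    x[mask]
except IndexError as err:
    stale = True
    print("stale mask rejected:", err)

# update() sidesteps this by starting over with an all-False mask of the
# new length, which also forgets the previous selection.
mask = np.zeros(x.shape, dtype=bool)
print(x[mask])  # []
```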

Python
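    def select(self, event1, event2):
        prevOffsets = []
        prevMask = None

        # With CTRL held, remember the current selection so the new drag adds to it
        if keyboard.is_pressed('ctrl'):
            prevOffsets = self._highlight.get_offsets()
            prevMask = self.mask

        # Start from a clean slate, then mark the points inside the dragged rectangle
        self.clear_highlights()
        self.mask |= self.inside(event1, event2)
        xy = np.column_stack([self.x[self.mask], self.y[self.mask]])

        # Merge the remembered selection back in
        if len(prevOffsets) > 0:
            xy = np.concatenate((xy, prevOffsets))
            self.mask |= prevMask

        # Move the yellow markers onto the selected coordinates
        if len(xy):
            self._highlight.set_offsets(xy)
            self.canvas.draw_idle()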

The most important method is select, which the RectangleSelector calls every time the user completes a drag. This works because we passed self.select to the RectangleSelector constructor as its second argument (the onselect callback).

This method does the following:

  1. Saves the currently highlighted coordinates and the current mask if the CTRL key is held down, so that a new drag adds to the existing selection instead of replacing it (multi-stage drag selection).
  2. Clears the old highlights, then calls the inside method, which returns a boolean mask marking which points lie inside the selected area. This is merged into the mask attribute. The inside method builds the mask from two events: the first is the press event (where the mouse button was held down) and the second is the release event (where it was released).
Python
    def inside(self, event1, event2):
        """Returns a boolean mask of the points inside the rectangle defined by
        event1 and event2."""
        x0, x1 = sorted([event1.xdata, event2.xdata])
        y0, y1 = sorted([event1.ydata, event2.ydata])
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask
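  3. Filters the x and y values with the mask and stacks them into xy, the coordinates of the points to highlight. If CTRL was held, the saved coordinates are appended to xy and the saved mask is merged back into the mask attribute. Finally, set_offsets is called on the highlight scatter (self._highlight) to place the yellow markers at those coordinates. The class also has its own set_offsets helper, which replaces the highlight scatter with a fresh one before drawing the given coordinates: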
Python
    def set_offsets(self, xy):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)

        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()

Highlighting is accomplished by drawing yellow scatter points over the points we want to select. Alternatively, you could modify the code to recolor the selected scatter points themselves instead.
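The three steps can be exercised without a GUI. The sketch below is a simplification under stated assumptions: the hypothetical function select_points takes the rectangle corners as plain (x, y) pairs instead of mouse events, and replaces the keyboard.is_pressed('ctrl') check with an optional prev_mask argument. Unlike the class, which appends the previously drawn offsets to xy, it rebuilds xy from the merged mask, which yields the same points without duplicates when two drags overlap.

```python
import numpy as np


def inside(x, y, corner1, corner2):
    """Boolean mask of points strictly inside the rectangle (same test as Highlighter.inside)."""
    x0, x1 = sorted([corner1[0], corner2[0]])
    y0, y1 = sorted([corner1[1], corner2[1]])
    return (x > x0) & (x < x1) & (y > y0) & (y < y1)


def select_points(x, y, corner1, corner2, prev_mask=None):
    """Hypothetical GUI-free version of Highlighter.select; returns (mask, xy)."""
    mask = inside(x, y, corner1, corner2)     # step 2: points in the new rectangle
    if prev_mask is not None:                 # step 1: "CTRL held", keep the old selection
        mask = mask | prev_mask
    xy = np.column_stack([x[mask], y[mask]])  # step 3: coordinates to highlight
    return mask, xy


x = np.array([1.0, 4.0, 7.0, 9.0])
y = np.array([2.0, 5.0, 8.0, 1.0])

# First drag from (0, 0) to (5, 6) catches the first two points.
mask, xy = select_points(x, y, (0, 0), (5, 6))
# Second drag with "CTRL" held, dragged backwards from (10, 9) to (6, 7),
# adds the third point; sorted() makes the drag direction irrelevant.
mask, xy = select_points(x, y, (10, 9), (6, 7), prev_mask=mask)

print(mask)         # [ True  True  True False]
print(xy.tolist())  # [[1.0, 2.0], [4.0, 5.0], [7.0, 8.0]]
```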

Python
    def clear_highlights(self):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        self.mask = np.zeros(self.x.shape, dtype=bool)
        self.canvas.draw_idle()

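Finally, the clear_highlights method, which select calls before drawing each new selection, removes the existing highlights: it swaps the highlight scatter for an empty one and resets the mask so that no points are selected.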
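For the alternative mentioned earlier, recoloring the selected points instead of overlaying a yellow scatter, a minimal sketch follows. It assumes that all points live in a single scatter collection (the driver application below instead creates one scatter per point), and the helper names highlight and clear are hypothetical. In this design, clearing means resetting every face color rather than removing an artist.

```python
import numpy as np
from matplotlib.colors import to_rgba
from matplotlib.figure import Figure

BASE = to_rgba("blue")         # (0.0, 0.0, 1.0, 1.0)
HIGHLIGHT = to_rgba("yellow")  # (1.0, 1.0, 0.0, 1.0)

fig = Figure()
ax = fig.add_subplot()
x = np.array([1.0, 4.0, 7.0])
y = np.array([2.0, 5.0, 8.0])
points = ax.scatter(x, y, color="blue")  # a single collection holding every point


def highlight(collection, mask):
    """Hypothetical helper: recolor the selected points in place (no overlay scatter)."""
    colors = np.where(mask[:, None], HIGHLIGHT, BASE)  # one RGBA row per point
    collection.set_facecolor(colors)


def clear(collection, n_points):
    """Hypothetical counterpart of clear_highlights: every point back to the base color."""
    collection.set_facecolor(np.tile(BASE, (n_points, 1)))


highlight(points, np.array([False, True, False]))
after_highlight = points.get_facecolor().copy()
print(after_highlight[1])  # [1. 1. 0. 1.]

clear(points, len(x))
```

In a live application, each call would be followed by canvas.draw_idle(), as in the Highlighter methods.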


Highlighter Class – Complete Code

Here is the complete combined code for the Highlighter class.

Python
import numpy as np
from matplotlib.widgets import RectangleSelector
import keyboard  # third-party package (pip install keyboard), used to detect CTRL

class Highlighter(object):
    def __init__(self, canvas, ax, x, y):
        self.canvas = canvas
        self.ax = ax
        self.x, self.y = x, y
        self.mask = np.zeros(x.shape, dtype=bool)

        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        self.selector = RectangleSelector(self.ax, self.select, useblit=True, state_modifier_keys={"center": "alt"})

    def update(self, x, y):
        self.x, self.y = x, y
        self.mask = np.zeros(x.shape, dtype=bool)

    def select(self, event1, event2):
        prevOffsets = []
        prevMask = None

        if keyboard.is_pressed('ctrl'):
            prevOffsets = self._highlight.get_offsets()
            prevMask = self.mask

        self.clear_highlights()
        self.mask |= self.inside(event1, event2)
        xy = np.column_stack([self.x[self.mask], self.y[self.mask]])

        if len(prevOffsets) > 0:
            xy = np.concatenate((xy, prevOffsets))
            self.mask |= prevMask

        if len(xy):
            self._highlight.set_offsets(xy)
            self.canvas.draw_idle()
    
    def clear_highlights(self):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)
        self.mask = np.zeros(self.x.shape, dtype=bool)
        self.canvas.draw_idle()

    def set_offsets(self, xy):
        self._highlight.remove()
        self._highlight = self.ax.scatter([], [], s=50, color='yellow', zorder=10)

        if len(xy):
            self._highlight.set_offsets(xy)
        self.canvas.draw_idle()

    def check_inside(self, point1, point2):
        """Returns a boolean mask of the points inside the rectangle whose opposite
        corners are point1 and point2, given as (x, y) pairs."""
        x0, x1 = sorted([point1[0], point2[0]])
        y0, y1 = sorted([point1[1], point2[1]])
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask

    def inside(self, event1, event2):
        """Returns a boolean mask of the points inside the rectangle defined by
        event1 and event2."""
        x0, x1 = sorted([event1.xdata, event2.xdata])
        y0, y1 = sorted([event1.ydata, event2.ydata])
        mask = ((self.x > x0) & (self.x < x1) &
                (self.y > y0) & (self.y < y1))
        return mask

The GUI application

Here is the code for our driver application, which wraps the Matplotlib graph in a Tkinter GUI. You are not required to use Tkinter, and adapting the code to work without it only takes a few small changes. That said, any graphing application that needs a feature this advanced will most likely use a GUI library alongside Matplotlib, since Matplotlib's own GUI capabilities are limited.
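For example, without Tkinter, Matplotlib's own event system can take the place of Tk's bind. The sketch below is hypothetical: ScatterDeleter is not part of the article's code, it keeps all points in one scatter collection, and its mask stands in for Highlighter.mask. It listens for "key_press_event" through mpl_connect, where Matplotlib reports the Delete key as "delete". A headless Agg canvas and a stand-in event are used so that it runs without a window.

```python
from types import SimpleNamespace

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


class ScatterDeleter:
    """Hypothetical Tk-free driver: one scatter collection plus a selection mask."""

    def __init__(self, canvas, ax, x, y):
        self.canvas = canvas
        self.x, self.y = np.asarray(x, float), np.asarray(y, float)
        self.mask = np.zeros(self.x.shape, dtype=bool)  # would come from Highlighter
        self.points = ax.scatter(self.x, self.y, color="blue")
        # Matplotlib's event system replaces Tk's bind("<Delete>", ...)
        self.cid = canvas.mpl_connect("key_press_event", self.on_key)

    def on_key(self, event):
        if event.key != "delete":  # Matplotlib's name for the Delete key
            return
        keep = ~self.mask
        self.x, self.y = self.x[keep], self.y[keep]
        self.points.set_offsets(np.column_stack([self.x, self.y]))
        self.mask = np.zeros(self.x.shape, dtype=bool)  # same reset as update()
        self.canvas.draw_idle()


fig = Figure()
canvas = FigureCanvasAgg(fig)  # with pyplot you would use plt.figure() and fig.canvas
ax = fig.add_subplot()
app = ScatterDeleter(canvas, ax, [1, 4, 7], [2, 5, 8])

app.mask[1] = True                         # pretend the middle point was selected
app.on_key(SimpleNamespace(key="delete"))  # stand-in for a real key press
print(app.x, app.y)  # [1. 7.] [2. 8.]
```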

The only method of note in the driver application below is delete, which is bound to the Delete key and removes the selected points. You can use it as a template for other actions, such as copy, paste, and move.

Python
import tkinter as tk
from tkinter import ttk
from matplotlib.figure import Figure
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from Highlighter import Highlighter
import pandas as pd
import random 

class MatplotlibApp:
    def __init__(self, master):
        self.master = master
        self.master.title("Matplotlib App")
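        # Typed empty columns, so concatenating integer rows later keeps an int64 dtype
        self.data = pd.DataFrame({"x": pd.Series(dtype="int64"), "y": pd.Series(dtype="int64")})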
        self.points = []
        self.create_widgets()

    def create_widgets(self):
        self.fig = Figure(figsize=(5, 4), dpi=100)
        self.ax = self.fig.add_subplot(111)
        self.ax.set_xlim(-1, 11)
        self.ax.set_ylim(-1, 11)    

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.master)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
        self.canvas.get_tk_widget().bind("<Delete>", self.delete)
        self.canvas.draw()

        tk.Button(self.master, text="Add Random Point", command=self.add).pack(padx=20, pady=20)
        self.highlighter = Highlighter(self.canvas, self.ax, self.data["x"], self.data["y"])

    def add(self):
        x, y = [random.randint(0, 10), random.randint(0, 10)]
        df = pd.DataFrame([[x, y]], columns=["x", "y"])

        self.data = pd.concat([self.data, df], ignore_index=True)
        self.points.append(self.ax.scatter(x, y, color="blue"))
        self.highlighter.update(self.data["x"], self.data["y"])
        self.canvas.draw()

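    def delete(self, event):
        # Boolean mask of the currently selected points
        self.selected_regions = self.highlighter.mask
        # Drop the selected rows; reset the index so it lines up with the next mask
        self.data = self.data[~self.selected_regions].reset_index(drop=True)

        # Remove the matching scatter artists from the axes and keep the rest
        for i, artist in enumerate(self.points):
            if self.selected_regions[i]:
                artist.remove()
        self.points = [artist for artist, m in zip(self.points, self.selected_regions) if not m]

        # Resize the mask to the new data and remove the yellow highlights
        self.highlighter.update(self.data["x"], self.data["y"])
        self.highlighter.clear_highlights()
        self.canvas.draw()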

def main():
    root = tk.Tk()
    app = MatplotlibApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()
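As mentioned above, delete can serve as a template for other actions. On the data side, delete, copy, and move all reduce to indexing the DataFrame with the highlighter's boolean mask. In this sketch the DataFrame and mask values are made up, and the copy and move operations are hypothetical examples rather than part of the application.

```python
import numpy as np
import pandas as pd

data = pd.DataFrame({"x": [1, 4, 7, 9], "y": [2, 5, 8, 1]})
mask = np.array([False, True, True, False])  # e.g. Highlighter.mask after a drag

# Delete: keep only unselected rows and renumber them so positions match a fresh mask.
deleted = data[~mask].reset_index(drop=True)

# Copy (hypothetical action): append duplicates of the selected rows.
copied = pd.concat([data, data[mask]], ignore_index=True)

# Move (hypothetical action): shift the selected points one unit to the right.
moved = data.copy()
moved.loc[mask, "x"] += 1

print(deleted.values.tolist())  # [[1, 2], [9, 1]]
print(copied.shape)             # (6, 2)
print(moved["x"].tolist())      # [1, 5, 8, 9]
```

In the application, each of these would be followed by the same bookkeeping that delete performs: updating the plotted artists, calling highlighter.update with the new columns, and clearing the highlights.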

This marks the end of the Interactive Scatter Plot with Matplotlib Rectangle Selector Tutorial. Any suggestions or contributions for CodersLegacy are more than welcome. Questions regarding the tutorial content may be asked in the comments section below.
