Mageek - 2 months ago

matplotlib: make legend keys square

I am working with matplotlib and would like to change the keys in my legends to be squares instead of rectangles when I make, for example, bar plots. Is there a way to specify this?

What I have now:

[image: bar chart legend with rectangular keys]

What I want:

[image: the same legend with square keys]

Thanks!

Answer

You can define your own legend keys.

The bar plot in my answer is based on the matplotlib barchart demo, with the error bars removed. The matplotlib legend guide explains how to define a handler class that replaces legend keys with ellipses; I have modified that class to draw squares instead (using rectangle patches).

import numpy as np
from matplotlib.legend_handler import HandlerPatch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Define square (rectangular) patches
# that can be used as legend keys
# (this code is based on the legend guide example)

class HandlerSquare(HandlerPatch):
    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
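        # lower-left corner of a height x height square, shifted right so the
        # square sits in the middle of the (wider) legend handle box
        xy = (-xdescent + 0.5 * (width - height), -ydescent)
        p = mpatches.Rectangle(xy=xy, width=height,
                               height=height, angle=0.0)
        self.update_prop(p, orig_handle, legend)
        p.set_transform(trans)
        return [p]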

# the bar chart below is adapted from the matplotlib barchart demo:

N = 5
menMeans = (20, 35, 30, 35, 27)

ind = np.arange(N)  # the x locations for the groups
width = 0.35       # the width of the bars

fig, ax = plt.subplots()
rects1 = ax.bar(ind, menMeans, width, color='r')

womenMeans = (25, 32, 34, 20, 25)
rects2 = ax.bar(ind+width, womenMeans, width, color='y')

# add some text for labels, title and axes ticks
ax.set_ylabel('Scores')
ax.set_title('Scores by group and gender')
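# the bars are centered on their x positions (the default since
# matplotlib 2.0), so each group's midpoint is half a bar width past ind
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(('G1', 'G2', 'G3', 'G4', 'G5'))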

# map the first bar of each series to HandlerSquare in the legend call:

ax.legend((rects1[0], rects2[0]), ('Men', 'Women'),
          handler_map={rects1[0]: HandlerSquare(), rects2[0]: HandlerSquare()})

plt.show()

Having defined the class HandlerSquare, you can apply it to each legend entry through the handler_map keyword argument of ax.legend. Note the syntax:

handler_map={rects1[0]: HandlerSquare(), rects2[0]: HandlerSquare()}

handler_map has to be a dictionary that maps each legend handle (here, the first bar of each series) to a handler instance.
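The keys of handler_map do not have to be individual handles: matplotlib looks up a handler first by the handle object itself and then by the handle's type (walking its class hierarchy), which is how the legend guide's ellipse example registers its handler. So a single entry keyed on `mpatches.Rectangle` can cover every bar handle. A minimal sketch, assuming the non-interactive `Agg` backend and a shortened version of the data above; the handler class is repeated so the snippet runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerPatch


class HandlerSquare(HandlerPatch):
    """Draw each legend key as a square centered in the handle box."""

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        # side length = box height; shift right so the square is centered
        xy = (-xdescent + 0.5 * (width - height), -ydescent)
        p = mpatches.Rectangle(xy=xy, width=height, height=height)
        self.update_prop(p, orig_handle, legend)  # copy color etc. from the bar
        p.set_transform(trans)
        return [p]


fig, ax = plt.subplots()
men = ax.bar([0, 1, 2], [20, 35, 30], width=0.35, color="r")
women = ax.bar([0.35, 1.35, 2.35], [25, 32, 34], width=0.35, color="y")

# keyed by type: every Rectangle handle in this legend uses HandlerSquare
leg = ax.legend([men[0], women[0]], ["Men", "Women"],
                handler_map={mpatches.Rectangle: HandlerSquare()})
```

Keying by type is convenient when a legend has many bar series, but it applies to every Rectangle handle in that legend, so use per-handle keys when only some entries should be square.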

This will give you this plot:

[image: the bar chart with square legend keys]
