user823743 - 2 months ago

How to draw a precision-recall curve with interpolation in python?

I have drawn a precision-recall curve using the precision_recall_curve function from sklearn and the matplotlib package. As those of you familiar with precision-recall curves know, some scientific communities only accept them when they are interpolated, similar to this example here. My question is: does anyone know how to do the interpolation in Python? I have been searching for a solution for a while now without success. Any help would be greatly appreciated.
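For reference, the interpolated precision at a recall level r is commonly defined as the highest precision reached at any recall greater than or equal to r:

```latex
p_{\mathrm{interp}}(r) = \max_{\tilde{r} \ge r} p(\tilde{r})
```

This makes the interpolated curve non-increasing in recall, which is what the backward passes in the solution and the answer compute.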

Solution: Both solutions, by @francis and @ali_m, are correct and together solved my problem. Assuming you have the output of the precision_recall_curve function in sklearn, here is what I did to plot the graph:

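import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

# y_test, scores, markers, mcounter, methodNames and countOfMethods come from the enclosing per-method loop
precision, recall, _ = precision_recall_curve(y_test.ravel(), scores.ravel())
average_precision = average_precision_score(y_test.ravel(), scores.ravel())
# sklearn returns recall in decreasing order; reverse both arrays so recall increases
prInv = precision[::-1].copy()
recInv = recall[::-1].copy()
# backward pass: replace each precision with the larger value to its right
j = recInv.shape[0] - 2
while j >= 0:
    if prInv[j + 1] > prInv[j]:
        prInv[j] = prInv[j + 1]
    j = j - 1
# vectorized equivalent of the loop above; redundant here but harmless
decreasing_max_precision = np.maximum.accumulate(prInv[::-1])[::-1]
plt.plot(recInv, decreasing_max_precision, marker=markers[mcounter],
         label=methodNames[countOfMethods] + ': AUC={0:0.2f}'.format(average_precision))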


Put these lines in a for loop and pass in the data of each method at each iteration, and they will plot the interpolated curve for every method. Note that this does not plot the non-interpolated precision-recall curves.
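If you only need the interpolated curve, the reversal and the backward pass can be folded into one vectorized NumPy helper. This is a sketch that assumes the arrays arrive in the order precision_recall_curve returns them (recall decreasing); the name interpolate_pr and the sample values are hypothetical, for illustration only.

```python
import numpy as np


def interpolate_pr(precision, recall):
    """Return (recall, interpolated precision) with recall increasing.

    Assumes sklearn's precision_recall_curve ordering: recall decreasing,
    with the last element equal to 0.
    """
    rec = np.asarray(recall, dtype=float)[::-1]
    prec = np.asarray(precision, dtype=float)[::-1]
    # Running maximum from the right: each entry becomes the largest precision
    # at or after its position, so the curve never rises as recall grows.
    interp = np.maximum.accumulate(prec[::-1])[::-1]
    return rec, interp


# Hypothetical values in sklearn's output order (recall decreasing).
precision = np.array([0.5, 0.45, 0.6, 0.5, 1.0])
recall = np.array([1.0, 0.8, 0.6, 0.2, 0.0])
rec, interp = interpolate_pr(precision, recall)
print(rec.tolist())     # [0.0, 0.2, 0.6, 0.8, 1.0]
print(interp.tolist())  # [1.0, 0.6, 0.6, 0.5, 0.5]
```

The returned pair can be passed straight to plt.plot inside the per-method loop.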

Answer

A backward iteration can be used to remove the increasing parts of the precision curve. Then vertical and horizontal lines can be plotted as described in Bennett Brown's answer to "vertical & horizontal lines in matplotlib".

Here is a sample code:

import numpy as np
import matplotlib.pyplot as plt

# just a dummy sample
recall = np.linspace(0.0, 1.0, num=42)
precision = np.random.rand(42) * (1. - recall)
precision2 = precision.copy()  # keep the raw curve for comparison
i = recall.shape[0] - 2

# interpolation: walk backwards so no precision is lower than any value to its right
while i >= 0:
    if precision[i + 1] > precision[i]:
        precision[i] = precision[i + 1]
    i = i - 1

# plotting: one vertical and one horizontal segment per recall step
fig, ax = plt.subplots()
for i in range(recall.shape[0] - 1):
    ax.plot((recall[i], recall[i]), (precision[i], precision[i + 1]), 'r-')  # vertical
    ax.plot((recall[i], recall[i + 1]), (precision[i + 1], precision[i + 1]), 'r-')  # horizontal

ax.plot(recall, precision2, 'b--')  # raw (non-interpolated) curve
# ax.legend()
ax.set_xlabel("recall")
ax.set_ylabel("precision")
plt.savefig('fig.jpg')
plt.show()

And here is a result:

[Figure: interpolated precision-recall curve (solid red steps) drawn over the raw dummy curve (dashed blue)]
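The two steps of this answer can also be sketched as plain NumPy functions: the backward pass, and a vertex list for the vertical-then-horizontal segments, so the whole staircase can be drawn with one plot call instead of two calls per point. The function names backward_interpolate and staircase and the sample values are hypothetical, for illustration only.

```python
import numpy as np


def backward_interpolate(precision):
    """Backward pass: each value becomes the max of itself and everything after it."""
    p = np.array(precision, dtype=float)  # copy so the caller's array is untouched
    for i in range(len(p) - 2, -1, -1):
        if p[i + 1] > p[i]:
            p[i] = p[i + 1]
    return p


def staircase(recall, precision):
    """Vertices of the vertical/horizontal segments as a single polyline.

    Segment i runs vertically from (r[i], p[i]) to (r[i], p[i+1]), then
    horizontally to (r[i+1], p[i+1]), matching the plotting loop above.
    """
    r = np.asarray(recall, dtype=float)
    p = np.asarray(precision, dtype=float)
    x = np.repeat(r, 2)[:-1]  # r0, r0, r1, r1, ..., r_last
    y = np.repeat(p, 2)[1:]   # p0, p1, p1, p2, p2, ..., p_last, p_last
    return x, y


# Hypothetical sample values, recall increasing as in the answer.
recall = np.linspace(0.0, 1.0, num=5)
precision = np.array([0.9, 0.5, 0.7, 0.3, 0.4])
interp = backward_interpolate(precision)
x, y = staircase(recall, interp)
print(interp.tolist())  # [0.9, 0.7, 0.7, 0.4, 0.4]
```

With matplotlib, ax.plot(x, y, 'r-') should then draw the same red staircase as the per-segment loop in a single call. ax.step(recall, interp, where='pre') is another option, since where='pre' holds each y value constant over the interval to the left of its x.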