Ivan - 1 year ago 181
Python Question

# Precision and Recall on PySpark DecisionTree model diverges from manual results

I trained a `DecisionTree` model on a PySpark DataFrame. The resulting DataFrame is simulated below:

```
rdd = sc.parallelize([
    (0., 1.),
    (0., 0.),
    (0., 0.),
    (1., 1.),
    (1., 0.),
    (1., 0.),
    (1., 1.),
    (1., 1.),
])
df = sqlContext.createDataFrame(rdd, ["prediction", "target_index"])
df.show()
```

```
+----------+------------+
|prediction|target_index|
+----------+------------+
|       0.0|         1.0|
|       0.0|         0.0|
|       0.0|         0.0|
|       1.0|         1.0|
|       1.0|         0.0|
|       1.0|         0.0|
|       1.0|         1.0|
|       1.0|         1.0|
+----------+------------+
```

So let's calculate a metric, recall:

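```
from pyspark.mllib.evaluation import MulticlassMetrics

metricsp = MulticlassMetrics(df.rdd)
# Label-less form; deprecated in Spark 2.0 in favor of `metricsp.accuracy`
print(metricsp.recall())
# 0.625
```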

Ok. Let's try to confirm that this is correct:

```
# Count each cell of the confusion matrix, treating 1.0 as the positive class
tp = df[(df.target_index == 1) & (df.prediction == 1)].count()
tn = df[(df.target_index == 0) & (df.prediction == 0)].count()
fp = df[(df.target_index == 0) & (df.prediction == 1)].count()
fn = df[(df.target_index == 1) & (df.prediction == 0)].count()
print("True Positives:", tp)
print("True Negatives:", tn)
print("False Positives:", fp)
print("False Negatives:", fn)
print("Total", df.count())
```

```
True Positives: 3
True Negatives: 2
False Positives: 2
False Negatives: 1
Total 8
```

and calculate recall:

```
r = float(tp) / (tp + fn)
print("recall", r)
# recall 0.75
```
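The same counts can be cross-checked without Spark. This is a minimal sketch in plain Python, assuming the eight `(prediction, target_index)` pairs from the simulated DataFrame above and treating label `1.0` as the positive class; `confusion_counts` is a hypothetical helper written only for illustration:

```python
from collections import Counter

# (prediction, target_index) pairs copied from the simulated DataFrame
pairs = [(0., 1.), (0., 0.), (0., 0.), (1., 1.),
         (1., 0.), (1., 0.), (1., 1.), (1., 1.)]

def confusion_counts(pairs, positive=1.0):
    """Hypothetical helper: return (tp, tn, fp, fn) for the given positive label."""
    counts = Counter()
    for pred, target in pairs:
        if target == positive:
            counts["tp" if pred == positive else "fn"] += 1
        else:
            counts["fp" if pred == positive else "tn"] += 1
    return counts["tp"], counts["tn"], counts["fp"], counts["fn"]

tp, tn, fp, fn = confusion_counts(pairs)
recall = tp / (tp + fn)        # share of actual positives that were predicted positive
precision = tp / (tp + fp)     # share of predicted positives that are actually positive
print(tp, tn, fp, fn)          # 3 2 2 1
print(recall, precision)       # 0.75 0.6
```

This agrees with the DataFrame counts, so the manual recall of 0.75 is the binary recall for label `1.0`.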

The results differ. What am I doing wrong?

BTW, all the label-less methods of the `MulticlassMetrics` class return the same value:

```
print(metricsp.recall())
print(metricsp.precision())
print(metricsp.fMeasure())
# 0.625
# 0.625
# 0.625
```

The problem is that you are using `MulticlassMetrics` to evaluate the output of a binary classifier without telling it which label is the positive class. From the docs:

```
recall()
Returns recall (equals to precision for multiclass classifier because sum of all false positives is equal to sum of all false negatives)
```
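To see why the label-less call returns 0.625 for recall, precision and F-measure alike, here is a plain-Python sketch of micro-averaging (summing the per-class counts before dividing), following the reasoning in the docstring above. It illustrates the arithmetic only and is not Spark's implementation: every misclassified row is a false positive for the class it was predicted as and a false negative for its true class, so the two sums are equal and all three metrics collapse to accuracy (5 correct rows out of 8).

```python
# (prediction, target_index) pairs copied from the simulated DataFrame
pairs = [(0., 1.), (0., 0.), (0., 0.), (1., 1.),
         (1., 0.), (1., 0.), (1., 1.), (1., 1.)]
labels = sorted({t for _, t in pairs} | {p for p, _ in pairs})

# Per-class counts, treating each label in turn as the "positive" class
tp = {c: sum(1 for p, t in pairs if p == c and t == c) for c in labels}
fp = {c: sum(1 for p, t in pairs if p == c and t != c) for c in labels}
fn = {c: sum(1 for p, t in pairs if p != c and t == c) for c in labels}

# Micro-averaging: sum the counts over all classes, then divide once
sum_tp, sum_fp, sum_fn = sum(tp.values()), sum(fp.values()), sum(fn.values())
micro_precision = sum_tp / (sum_tp + sum_fp)
micro_recall = sum_tp / (sum_tp + sum_fn)
micro_f1 = 2 * micro_precision * micro_recall / (micro_precision + micro_recall)
accuracy = sum(1 for p, t in pairs if p == t) / len(pairs)

print(sum_fp, sum_fn)                                     # 3 3
print(micro_precision, micro_recall, micro_f1, accuracy)  # 0.625 0.625 0.625 0.625
```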

To get the recall of the positive class, pass the label explicitly with `recall(label=1)`:

```
>>> print(metricsp.recall(label=1))
0.75
```
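Per-label metrics depend on which class is treated as positive, which is why the label argument matters. A plain-Python sketch, assuming one-vs-rest definitions for each label; `per_label_metrics` is a hypothetical helper, not part of PySpark:

```python
# (prediction, target_index) pairs copied from the simulated DataFrame
pairs = [(0., 1.), (0., 0.), (0., 0.), (1., 1.),
         (1., 0.), (1., 0.), (1., 1.), (1., 1.)]

def per_label_metrics(pairs, label):
    """Hypothetical helper: one-vs-rest precision, recall and F1 for `label`."""
    tp = sum(1 for p, t in pairs if p == label and t == label)
    fp = sum(1 for p, t in pairs if p == label and t != label)
    fn = sum(1 for p, t in pairs if p != label and t == label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

for label in (0.0, 1.0):
    precision, recall, f1 = per_label_metrics(pairs, label)
    print(label, round(precision, 4), round(recall, 4), round(f1, 4))
```

For label `1.0` this gives precision 0.6 and recall 0.75, matching the manual calculation; for label `0.0` it gives precision 2/3 and recall 0.5. In PySpark the corresponding per-label calls are `metricsp.precision(1.0)`, `metricsp.recall(1.0)` and `metricsp.fMeasure(1.0)`.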

BTW, the `df.show()` headers must follow the column order passed to `createDataFrame`, i.e. `prediction|target_index`, as in the table above.