Alex - 1 year ago

Python Question

This is a follow-up to the question "How to know what classes are represented in return array from predict_proba in Scikit-learn".

In that question, I quoted the following code:

```
>>> import sklearn
>>> sklearn.__version__
'0.13.1'
>>> from sklearn import svm
>>> model = svm.SVC(probability=True)
>>> X = [[1,2,3], [2,3,4]] # feature vectors
>>> Y = ['apple', 'orange'] # classes
>>> model.fit(X, Y)
>>> model.predict_proba([1,2,3])
array([[ 0.39097541, 0.60902459]])
```

In that question I learned that this result gives the probability of the point belonging to each class, in the order given by `model.classes_`:

```
>>> zip(model.classes_, model.predict_proba([1,2,3])[0])
[('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
```
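Note that in Python 3 `zip` returns an iterator, so the pairing above needs `list(...)` or `dict(...)` around it to display anything useful. A small pure-Python sketch of the same pairing, assuming `classes` stands in for `model.classes_` and `proba_row` for one row of `predict_proba` output (the helper name `class_probabilities` is hypothetical, not part of scikit-learn):

```python
def class_probabilities(classes, proba_row):
    """Map each class label to its probability.

    Column i of a predict_proba row belongs to classes[i], i.e. model.classes_[i].
    """
    if len(classes) != len(proba_row):
        raise ValueError("expected exactly one probability per class")
    return dict(zip(classes, proba_row))


# Values copied from the transcript above (scikit-learn 0.13.1).
probs = class_probabilities(['apple', 'orange'],
                            [0.39097541289393828, 0.60902458710606167])
most_likely = max(probs, key=probs.get)  # label with the highest probability

print(probs)
print(most_likely)  # orange
```

The pairing only relabels the columns; it cannot change which class gets the larger probability.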

So... this answer, if interpreted correctly, says that the point is probably an 'orange' (with fairly low confidence, given the tiny amount of data). But intuitively this result is clearly wrong, since the point given is identical to the training point for 'apple'. Just to be sure, I tested the reverse as well:

```
>>> zip(model.classes_, model.predict_proba([2,3,4])[0])
[('apple', 0.60705475211840931), ('orange', 0.39294524788159074)]
```

Again, obviously incorrect, but in the other direction.

Finally, I tried it with training points that were much further apart.

```
>>> X = [[1,1,1], [20,20,20]] # feature vectors
>>> model.fit(X, Y)
>>> zip(model.classes_, model.predict_proba([1,1,1])[0])
[('apple', 0.33333332048410247), ('orange', 0.66666667951589786)]
```

Again, the model predicts the wrong probabilities. BUT the `model.predict` function gets it right!

```
>>> model.predict([1,1,1])[0]
'apple'
```

Now, I remember reading something in the docs about `predict_proba` being inaccurate for small datasets, though I can't seem to find it again. Is this the expected behaviour, or am I doing something wrong? If this IS the expected behaviour, then why do the `predict` and `predict_proba` functions disagree on the output? And importantly, how big does the dataset need to be before I can trust the results from `predict_proba`?
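For context: the current scikit-learn user guide states that `SVC(probability=True)` calibrates its decision values with Platt scaling (a logistic fit on the SVM scores, trained with an extra internal cross-validation), and warns that the resulting probabilities may be inconsistent with the scores; in the binary case a sample can be labelled positive by `predict` even though `predict_proba` gives it less than 0.5. The sketch below shows the mechanism with made-up sigmoid parameters `A` and `B` (hypothetical values, not taken from any fitted model): `predict` follows the sign of the decision value `f`, while the probability follows a separately fitted sigmoid of libsvm's form, P(positive | f) = 1 / (1 + exp(A*f + B)), which crosses 0.5 at f = -B/A rather than at f = 0.

```python
import math


def platt_probability(f, A, B):
    """Positive-class probability under libsvm-style Platt scaling."""
    return 1.0 / (1.0 + math.exp(A * f + B))


def svm_label(f):
    """Label implied by the sign of the decision value (binary predict)."""
    return 'positive' if f > 0 else 'negative'


# Hypothetical sigmoid parameters, as if fitted by the internal cross-validation.
A, B = -2.0, 1.0
f = 0.3  # a decision value just on the positive side of the hyperplane

p = platt_probability(f, A, B)
print(svm_label(f), round(p, 3))  # positive 0.401: label and probability disagree
print(-B / A)  # 0.5: where the sigmoid crosses probability 0.5, instead of f = 0
```

Whenever B is not zero, any decision value between 0 and -B/A gets a label and a probability that point in opposite directions.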

OK, so I did some more 'experiments' on this: the behaviour of `predict_proba` depends heavily on `n` (the number of copies of the training data), but not in any predictable way!

```
>>> def train_test(n):
...     X = [[1,2,3], [2,3,4]] * n
...     Y = ['apple', 'orange'] * n
...     model.fit(X, Y)
...     print "n =", n, zip(model.classes_, model.predict_proba([1,2,3])[0])
...
>>> train_test(1)
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
>>> for n in range(1,10):
...     train_test(n)
...
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
n = 2 [('apple', 0.98437355278112448), ('orange', 0.015626447218875527)]
n = 3 [('apple', 0.90235408180319321), ('orange', 0.097645918196806694)]
n = 4 [('apple', 0.83333299908143665), ('orange', 0.16666700091856332)]
n = 5 [('apple', 0.85714254878984497), ('orange', 0.14285745121015511)]
n = 6 [('apple', 0.87499969631893626), ('orange', 0.1250003036810636)]
n = 7 [('apple', 0.88888844127886335), ('orange', 0.11111155872113669)]
n = 8 [('apple', 0.89999988018127364), ('orange', 0.10000011981872642)]
n = 9 [('apple', 0.90909082368682159), ('orange', 0.090909176313178491)]
```
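Part of this pattern can be checked with plain arithmetic. Platt's method as implemented in libsvm fits the sigmoid not to hard 0/1 labels but to regularized targets: (N+ + 1)/(N+ + 2) for positive examples and 1/(N- + 2) for negative ones, where N+ and N- are the class counts. With `n` copies of each point both counts are `n`, and if every copy of a point gets the same decision value, the best-fitting sigmoid assigns it exactly its target, i.e. (n + 1)/(n + 2) for the correct class (with equal counts, 1 - 1/(n + 2) is the same number, so it does not matter which class libsvm treats as positive). The sketch below compares that value with the 'apple' probabilities reported above; they agree to within 1e-6 for n = 4 to 9, but not for n ≤ 3, so this only accounts for part of the behaviour.

```python
from fractions import Fraction

# 'apple' probabilities for the point [1, 2, 3], copied from the n = 1..9 output above.
reported = {
    1: 0.39097541289393828,
    2: 0.98437355278112448,
    3: 0.90235408180319321,
    4: 0.83333299908143665,
    5: 0.85714254878984497,
    6: 0.87499969631893626,
    7: 0.88888844127886335,
    8: 0.89999988018127364,
    9: 0.90909082368682159,
}


def platt_high_target(n_positive):
    """Platt's regularized target for positive examples: (N+ + 1) / (N+ + 2)."""
    return Fraction(n_positive + 1, n_positive + 2)


matches = []
for n, p in reported.items():
    target = platt_high_target(n)
    agrees = abs(p - float(target)) < 1e-6
    if agrees:
        matches.append(n)
    print(f"n = {n}: reported {p:.6f}, (n+1)/(n+2) = {target}, {'match' if agrees else 'no match'}")

print(matches)  # [4, 5, 6, 7, 8, 9]
```

One consequence, if this reading is right: on such data the reported probability can never reach 1, however many copies are added.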

How should I use this function safely in my code? At the very least, is there any value of n for which it will be guaranteed to agree with the result of model.predict?
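One way to keep the two from contradicting each other (a sketch, not a fix for the calibration itself) is to derive the label from the same probability row you report, i.e. take the class with the highest probability, and to flag rows where that label differs from `model.predict`. The helper names below are hypothetical; `classes`, `proba_rows` and `predicted` stand in for `model.classes_`, `model.predict_proba(X)` and `model.predict(X)`.

```python
def labels_from_proba(classes, proba_rows):
    """For each row, return the class with the highest probability (argmax)."""
    labels = []
    for row in proba_rows:
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(classes[best])
    return labels


def disagreements(classes, proba_rows, predicted):
    """Indices of rows where predict() and the argmax of predict_proba() differ."""
    from_proba = labels_from_proba(classes, proba_rows)
    return [i for i, (a, b) in enumerate(zip(predicted, from_proba)) if a != b]


# Values from the [1,1,1] / [20,20,20] experiment above.
classes = ['apple', 'orange']
proba_rows = [[0.33333332048410247, 0.66666667951589786]]
predicted = ['apple']

print(labels_from_proba(classes, proba_rows))         # ['orange']
print(disagreements(classes, proba_rows, predicted))  # [0]
```

Taking the argmax keeps labels and probabilities consistent, but it does not make the probabilities any more accurate; in this experiment it is `predict` that matches the training label.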


Answer

If you use `svm.LinearSVC()` as the estimator, and its `.decision_function()` (which plays a role similar to `svm.SVC`'s `.predict_proba()`) to sort the classes from most probable to least probable, the ranking agrees with the `.predict()` function. Plus, this estimator is faster and gives almost the same results as `svm.SVC(kernel='linear')`.

The only drawback for you might be that `.decision_function()` returns a signed score (related to the distance from the separating hyperplane; something like between -1 and 3 in practice) instead of a probability. But it agrees with the prediction.
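A sketch of the ranking this answer describes, in plain Python so it runs without scikit-learn. It assumes the usual scikit-learn output shapes: with more than two classes, `decision_function` gives one score per class (in `classes_` order, higher meaning more likely); with two classes it gives a single score per sample, where a positive value favours `classes_[1]`. The helper `rank_classes` and the example scores are hypothetical.

```python
import numbers


def rank_classes(classes, score):
    """Order class labels from most to least likely using decision_function output.

    `score` is one sample's output: a single number in the binary case
    (positive favours classes[1]) or one score per class otherwise.
    """
    if isinstance(score, numbers.Real):
        # Binary case: expand the single signed score into per-class scores.
        per_class = [-score, score]
    else:
        per_class = list(score)
    order = sorted(range(len(classes)), key=lambda i: per_class[i], reverse=True)
    return [classes[i] for i in order]


# Hypothetical scores, for illustration only.
print(rank_classes(['apple', 'orange'], -0.8))                     # ['apple', 'orange']
print(rank_classes(['apple', 'orange', 'pear'], [0.2, -1.1, 2.7]))  # ['pear', 'apple', 'orange']
```

The top-ranked class is the one `predict` returns (sign in the binary case, argmax otherwise), which is why this ordering cannot contradict it; the scores themselves are still not probabilities.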
