Anupam Sobti - 10 days ago
Python Question

Python OpenCV 3.1: KNN not working with the example in the docs

I was trying to get the example from the docs page working (http://docs.opencv.org/3.1.0/d5/d26/tutorial_py_knn_understanding.html).

In OpenCV 3.x, cv2.KNearest() has been replaced with cv2.ml.KNearest_create().

However, the following code snippet still results in an error.

#!/usr/bin/python3
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Feature set containing (x,y) values of 25 known/training data
trainData = np.random.randint(0,100,(25,2)).astype(np.float32)

# Labels each one either Red or Blue with numbers 0 and 1
responses = np.random.randint(0,2,(25,1)).astype(np.float32)

# Take Red families and plot them
red = trainData[responses.ravel()==0]
plt.scatter(red[:,0],red[:,1],80,'r','^')

# Take Blue families and plot them
blue = trainData[responses.ravel()==1]
plt.scatter(blue[:,0],blue[:,1],80,'b','s')

newcomer = np.random.randint(0,100,(1,2)).astype(np.float32)
plt.scatter(newcomer[:,0],newcomer[:,1],80,'g','o')

knn = cv2.ml.KNearest_create()
knn.train(trainData,responses)
ret, results, neighbours ,dist = knn.find_nearest(newcomer, 3)

print ("result: ", results,"\n")
print ("neighbours: ", neighbours,"\n")
print ("distance: ", dist)
plt.show()


I get the following error on execution:

Traceback (most recent call last):
  File "./knn_test.py", line 24, in <module>
    knn.train(trainData,responses)
TypeError: only length-1 arrays can be converted to Python scalars


According to its docstring, knn.train() expects:

>>> knn.train.__doc__
'train(trainData[, flags]) -> retval or train(samples, layout, responses) -> retval'


I couldn't find an example of how to define layout. What change is required to get this working?
Thanks in advance!

s1h
Answer

The KNN classifier is derived from the StatModel base class.

The layout specifier is an integer that tells the model whether a single sample occupies one row or one column (see StatModel::train and ml::SampleTypes).
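The message in the traceback comes from NumPy: it is what NumPy raises when an array with more than one element is converted to a single Python number. A plausible reading, which is an assumption about the binding's internals rather than something the OpenCV docs state, is that the two-argument call knn.train(trainData, responses) sends the 25x1 responses array into an integer slot (flags or layout). The sketch below reproduces the same kind of TypeError with NumPy alone, using the question's variable names, and checks the shape condition that makes row layout the right choice here.

```python
import numpy as np

# Same shapes as in the question: 25 samples with 2 features each, 25 labels.
trainData = np.random.randint(0, 100, (25, 2)).astype(np.float32)
responses = np.random.randint(0, 2, (25, 1)).astype(np.float32)

# Assumption: this mirrors what happens when `responses` lands in an int parameter.
# A 25-element array cannot become a single Python int, so NumPy raises TypeError.
try:
    int(responses)
    conversion_failed = False
except TypeError as exc:
    conversion_failed = True
    print("TypeError:", exc)

# One sample per row: the number of rows in the samples matches the number of labels.
one_sample_per_row = trainData.shape[0] == responses.shape[0]
print("one sample per row:", one_sample_per_row)
```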

Since you've got 25 rows of samples, you'll need to pass cv2.ml.ROW_SAMPLE as the layout argument: knn.train(trainData, cv2.ml.ROW_SAMPLE, responses)