Frido Frido - 7 months ago 63
Python Question

Getting an error while giving input to SVM in OpenCV

I am trying to classify images as red or green using an SVM. For training, I extracted the RGBA values from the training images and converted my list to a NumPy array, but I get an error when I pass it to the SVM for training. My sample code is:

import cv2
import numpy as np
from PIL import Image
import os
print "OpenCV version : {0}".format(cv2.__version__)
svm_params = dict( kernel_type = cv2.SVM_LINEAR,
                   svm_type = cv2.SVM_C_SVC,
                   C=2.67, gamma=5.383 )

path1='c:\\colors\\red\\'
path2='c:\\colors\\green\\'
training_set = []
test_set=[]
training_labels=[]
rlist = os.listdir(path1)
glist= os.listdir(path2)
for file in rlist:
    img = Image.open(path1 + file)
    img200=img.resize((100,100)).convert('RGBA')
    arr= np.array(img200)
    print arr
    training_set.append(arr)
    training_labels.append(1)
for file in glist:
    img = Image.open(path2 + file)
    img200=img.resize((100,100)).convert('RGBA')
    arr= np.array(img200)
    training_set.append(arr)
    training_labels.append(2)
###### SVM training ########################
trainData=np.float32(training_set)
responses=np.float32(training_labels)
svm = cv2.SVM()
svm.train(trainData,responses, params=svm_params)
svm.save('trycolor_svm_data.dat')


I am getting this error:

cv2.error: ..\..\..\..\opencv\modules\ml\src\inner_functions.cpp:857: error: (-5) train data must be floating-point matrix in function cvCheckTrainData


How can I correctly give input to the SVM?

Answer

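If you print 'arr' you will see that it is a nested array: each image is a 3-D array of shape (100, 100, 4), so trainData ends up 4-D. That is the problem. The SVM expects a 2-D floating-point matrix with one row per training sample, so you need to flatten each image into a 1-D vector before appending it to the training set.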

flat_arr= arr.ravel()

Here is the corrected code.

for file in rlist:
    img = Image.open(path1 + file)
    img200=img.resize((100,100)).convert('RGBA')
    arr= np.array(img200)
    # flatten the 100x100x4 image into one row of 40000 values
    flat_arr= arr.ravel()
    training_set.append(flat_arr)
    training_labels.append(1)
for file in glist:
    img = Image.open(path2 + file)
    img200=img.resize((100,100)).convert('RGBA')
    arr= np.array(img200)
    flat_arr= arr.ravel()
    training_set.append(flat_arr)
    training_labels.append(2)
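
To see what the flattening changes, here is a minimal sketch that uses only NumPy and two synthetic 100x100 RGBA arrays in place of the image files. The helper name build_training_matrix and the synthetic images are assumptions made for illustration; they are not part of the original code. It compares the shape of the training data with and without ravel().

```python
import numpy as np


def build_training_matrix(images):
    """Flatten each H x W x C image into one row of a 2-D float32 matrix.

    Hypothetical helper for illustration; assumes all images share one shape.
    """
    rows = [np.asarray(img).ravel() for img in images]
    return np.float32(rows)


# Synthetic stand-ins for two resized RGBA images (assumed, not real files).
red = np.zeros((100, 100, 4), dtype=np.uint8)
red[..., 0] = 255    # R channel
red[..., 3] = 255    # alpha channel
green = np.zeros((100, 100, 4), dtype=np.uint8)
green[..., 1] = 255  # G channel
green[..., 3] = 255  # alpha channel

# Without flattening: one 3-D block per sample, so the result is 4-D.
unflattened = np.float32([red, green])

# With flattening: one row per sample, which is the 2-D layout the SVM needs.
train_data = build_training_matrix([red, green])
responses = np.float32([1, 2])

print(unflattened.shape)
print(train_data.shape)
```

The second shape has one row per image and 100 * 100 * 4 columns, so train_data and responses have matching first dimensions.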