我正在尝试使用OpenCV上的ANN对不同的颜色桶进行分类。这是我的程序:
import cv2
import numpy as np
import pickle
def load_data():
db = open('/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/green_barrel.pkl','rb')
training_data = pickle.load(db)
db.close()
return training_data
def wrap_data():
tr_d = load_data()
tr_d = np.array(tr_d)
training_inputs = [np.reshape(x,(1,1,-1)) for x in tr_d[0]]
training_results = [vectorized_result(y) for y in tr_d[1]]
training_data = zip(training_inputs,training_results)
return training_data
def vectorized_result(j):
e = np.zeros((2,1),dtype=np.uint8)
e[True] = 1.0
return e
def create_ANN(hidden=50):
ann = cv2.ml.ANN_MLP_create()
ann.setLayerSizes(np.array([2500,hidden,2]))
ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
ann.setTermCriteria((cv2.TERM_CRITERIA_EPS|cv2.TERM_CRITERIA_COUNT,20,1))
return ann
def train(ann,samples,epochs=2):
tr = wrap_data()
for i in range(epochs):
counter = 0
for img in tr:
if (counter>samples):
break
if (counter%1000==0):
print('Epoch %d: Trained %d/%d'%(0,counter,samples))
counter+=1
data = np.array([img])
data = data.ravel()
print(data)
ann.train(np.array([data]),cv2.ml.ROW_SAMPLE,np.array([data]))
print('Epoch %d complete'%i)
return ann
def predict(ann,samples):
resized = sample.copy()
rows,cols = resized.shape
if (rows!=50 or cols !=50) and rows*cols>0:
resized = cv2.resize(resized,(50,50),interpolation=cv2.INTER_CUBIC)
return ann.predict(np.array([resized.ravel()],dtype=np.float32))
ann = train(create_ANN(56),50)
但是我得到这个错误:
Traceback (most recent call last):
File "/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/__init__.py", line 59, in <module>
ann = train(create_ANN(56),50)
File "/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/__init__.py", line 48, in train
ann.train(np.array([data]),cv2.ml.ROW_SAMPLE,np.array([data]))
TypeError: Expected Ptr<cv::UMat> for argument '%s'
当我打印data
时,我得到
[array([[[56, 62, 58]]], dtype=uint8)
array([[1],
[1]], dtype=uint8)]
当我尝试更改数据的dtype时,出现ValueError。 您可以从https://github.com/Rishan123/CERCBOT-PiWars-2020/blob/master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/green_barrel.pkl?raw=true
获取green_barrel.pkl