TypeError:Opencv ANN.train的参数'%s'的预期Ptr <cv :: UMat>

时间:2019-12-21 10:50:44

标签: python opencv machine-learning raspberry-pi computer-vision

我正在尝试使用OpenCV上的ANN对不同的颜色桶进行分类。这是我的程序:

import cv2
import numpy as np
import pickle

def load_data():
    db = open('/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/green_barrel.pkl','rb')
    training_data = pickle.load(db)
    db.close()
    return training_data    

def wrap_data():
    tr_d = load_data()
    tr_d = np.array(tr_d)
    training_inputs = [np.reshape(x,(1,1,-1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs,training_results)
    return training_data

def vectorized_result(j):
    e = np.zeros((2,1),dtype=np.uint8)
    e[True] = 1.0
    return e

def create_ANN(hidden=50):
    ann = cv2.ml.ANN_MLP_create()
    ann.setLayerSizes(np.array([2500,hidden,2]))
    ann.setTrainMethod(cv2.ml.ANN_MLP_RPROP)
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    ann.setTermCriteria((cv2.TERM_CRITERIA_EPS|cv2.TERM_CRITERIA_COUNT,20,1))
    return ann

def train(ann,samples,epochs=2):
    tr = wrap_data()

    for i in range(epochs):
        counter = 0
        for img in tr:
            if (counter>samples):
                break
            if (counter%1000==0):
                print('Epoch %d: Trained %d/%d'%(0,counter,samples))
            counter+=1
            data = np.array([img])
            data = data.ravel()
            print(data)

            ann.train(np.array([data]),cv2.ml.ROW_SAMPLE,np.array([data]))
    print('Epoch %d complete'%i)
    return ann

def predict(ann,samples):
    resized = sample.copy()
    rows,cols = resized.shape
    if (rows!=50 or cols !=50) and rows*cols>0:
        resized = cv2.resize(resized,(50,50),interpolation=cv2.INTER_CUBIC)
    return ann.predict(np.array([resized.ravel()],dtype=np.float32))

ann = train(create_ANN(56),50)

但是我得到这个错误:

Traceback (most recent call last):
  File "/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/__init__.py", line 59, in <module>
    ann = train(create_ANN(56),50)
  File "/home/pi/CERCBOT-PiWars-2020-master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/__init__.py", line 48, in train
    ann.train(np.array([data]),cv2.ml.ROW_SAMPLE,np.array([data]))
TypeError: Expected Ptr<cv::UMat> for argument '%s'

当我打印data时,我得到

[array([[[56, 62, 58]]], dtype=uint8)
 array([[1],
       [1]], dtype=uint8)]

当我尝试更改数据的dtype时,出现ValueError。 您可以从https://github.com/Rishan123/CERCBOT-PiWars-2020/blob/master/Eco-Disaster/OpenCV/Object_Detection/barrel_detector/green_barrel.pkl?raw=true

获取green_barrel.pkl

0 个答案:

没有答案