How do I feed an image ROI into Tensorflow's session.run()?

Time: 2019-03-08 14:40:33

Tags: python dictionary tensorflow

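I am trying to feed image ROIs (regions of interest) into a Tensorflow classifier that I got from here. The idea is to first run a simple filter to get rectangular candidates, and then check (using the network) whether each rectangle (ROI) really is what I am looking for.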

class ScrewDetector:
    def __init__(self):
        self.session = None # an internal variable needed for inception network

        # to keep the screw data in
        self.screw_data = dict()

        # load the labels of the classification: screw / non-screw
        self.class_labels = [line.rstrip() for line in tf.gfile.GFile(home + "/imagine_weights/screw_detector/retrained_labels.txt")]

        # prepare the network
        with tf.gfile.FastGFile(home + "/weights/screw_detector/retrained_graph.pb", 'rb') as f:
            graph_def = tf.GraphDef()   ##  the graph-graph_def is a saved copy of a TensorFlow graph, object initialization
            graph_def.ParseFromString(f.read()) # parse serialized protocol buffer data into variable
            _ = tf.import_graph_def(graph_def, name='') # import a serialized TensorFlow GraphDef protocol buffer, extract objects in the GraphDef as tf.Tensor

        # start the session
        with tf.Session() as self.session:
            self.softmax_tensor = self.session.graph.get_tensor_by_name('final_result:0')


    def detect_screw(self):

        # get a copy and resize it
        img_raw = self.cv_image.copy()
        resized_img = cv2.resize(img_raw, (0,0), fx=RESIZE_FACTOR, fy=RESIZE_FACTOR)

        # grayscale it
        gray = cv2.cvtColor(resized_img, cv2.COLOR_BGR2GRAY)

        # detect circles in the image
        circles = cv2.HoughCircles(gray, cv2.HOUGH_GRADIENT, 1, 100, param1=50,param2=35,minRadius=15,maxRadius=30)

        # ensure at least some circles were found
        if circles is not None:
            # convert the (x, y) coordinates and radius of the circles to integers
            circles = np.round(circles[0, :]).astype("int")

            # get a counter
            screw_id = 0
            # loop over the (x, y) coordinates and radius of the circles
            for (x, y, r) in circles:
                # draw the circle in the output image, then draw a rectangle corresponding to the center of the circle
                #cv2.circle(resized_img, (x, y), r, (0, 255, 0), 4)
                cv2.rectangle(resized_img, (x - r, y - r), (x + r, y + r), (0, 0, 255), 5)

                # get the above rectangle as ROI, clamped to the image bounds
                screw_roi = resized_img[max(y - r, 0):y + r, max(x - r, 0):x + r]

                # feed it into the network
                #import IPython; IPython.embed()
                predictions = self.session.run(self.softmax_tensor, feed_dict={screw_id: [screw_roi.flatten()]})

                # get prediction values in array back
                top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

                # output
                for node_id in top_k:
                    human_string = self.class_labels[node_id]
                    score = predictions[0][node_id]
                    print('%s (score = %.5f)' % (human_string, score))

                    # if it is a screw, go on, save its coordinates and append into the network

                    # remap in the original image
                    scaled_point = (round(x * (1/RESIZE_FACTOR)), round(y * (1/RESIZE_FACTOR)))

                    # append to the dict
                    self.screw_data[scaled_point] = r * RESIZE_FACTOR

                    # increment the counter
                    screw_id += 1

            #  publish the result, which is an image (scaled) 
            result_image_msg = Image()
            try:
                result_image_msg = self.bridge.cv2_to_imgmsg(resized_img, "bgr8") 
                #print(self.screw_data)
            except CvBridgeError as e:
                print("Could not make it through the cv bridge of death.")

            self.result_image_pub.publish(result_image_msg)
        else:
            print("No detection of circles.")

But I get:

TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a int into a Tensor.

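I know that the variables screw_id and screw_roi are not empty. I also know that a dictionary has to be fed in, which is why I tried it this way in the first place. But because of the error above, I cannot get it to run.

Any ideas?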

Edit: Normally, this code loads an image and makes a prediction like this:

image_data = tf.gfile.FastGFile(image_path, 'rb').read()
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})

All I want to do is turn this into a form that works on image ROIs supplied at runtime. It can't be that complicated.

1 Answer:

Answer 0: (score: 0)

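feed_dict expects a dictionary with tensors as keys, which is used to fill the placeholders with the given values. How screw_id is initialized is not in your snippet, but I bet it is not a tensor of any kind, hence your error.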
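
One possible way to act on this, as a minimal sketch rather than a tested fix: the edit in the question shows that the graph accepts JPEG bytes under the tensor name 'DecodeJpeg/contents:0', so the ROI could be encoded as JPEG in memory and fed under that name instead of the integer screw_id. The helper names classify_roi and encode_with_opencv are hypothetical, and the sketch assumes the session is still open when it is called and that the ROI is a non-empty BGR array.

```python
def classify_roi(session, output_tensor, roi, encode_jpeg,
                 input_name='DecodeJpeg/contents:0'):
    """Run the classifier on a single ROI and return its score vector.

    session, output_tensor: the tf.Session and 'final_result:0' tensor
        from the question.
    encode_jpeg: any callable turning the ROI into JPEG bytes
        (hypothetical parameter, passed in to keep the sketch generic).
    input_name: assumed to be the input used in the question's edit.
    """
    jpeg_bytes = encode_jpeg(roi)
    # The feed_dict key is a tensor name (or a tf.Tensor), never an int.
    predictions = session.run(output_tensor,
                              feed_dict={input_name: jpeg_bytes})
    return predictions[0]


def encode_with_opencv(roi):
    """Encode a BGR image array as JPEG bytes using OpenCV."""
    import cv2  # imported lazily so classify_roi itself needs no OpenCV

    ok, buf = cv2.imencode('.jpg', roi)
    if not ok:
        raise ValueError('JPEG encoding of the ROI failed')
    return buf.tobytes()
```

Inside detect_screw, the call would then look like scores = classify_roi(self.session, self.softmax_tensor, screw_roi, encode_with_opencv), with screw_roi passed as the image array itself rather than flattened.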