无法为形状为((1,299,299,3)''的张量'import / Mul:0'输入形状()的值

时间:2018-07-14 17:32:21

标签: python tensorflow image-recognition

我正在尝试从tensorflow中的'test_images'目录中分类多个图像,并且出现以下错误:“无法为Tensor'import / Mul:0'输入shape()的值,其形状为'(1,299 ,299,3)'“。我还使用输入队列来输入一批要分类的图像,由于某种原因,这不起作用。任何帮助将不胜感激!

完整代码:

class image_recognition_algorithm():

def __init__(self, image_file, model_file, label_file):
    self.model_file = model_file
    self.label_file = label_file
    self.image_file = image_file


def load_graph(self):
    graph = tf.Graph()
    graph_def = tf.GraphDef()

    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with graph.as_default():
        tf.import_graph_def(graph_def)

    return graph

def read_images_from_file(self, image_file, input_height=299, input_width=299,
                input_mean=128, input_std=128):
    input_queue = "file_contents"
    output_name = "normalized"
    res_list = []
    for i in res_list:
        file_names = res_list.append(join('test_images', i))
        input_queue = tf.gfile.FastGFile(file_names, 'rb').read()
        file_contents = tf.read_file(input_queue)
        image_reader = tf.image.decode_jpeg(file_contents, channels=3)
        float_caster = tf.expand_dims(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0);
        resized = tf.image.resize_bilinear(dims_expander, [input_height, input_width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
        images = tf.convert_to_tensor(file_contents)
        input_queue = tf.train.slice_input_producer([images], shuffle=True)
        image = read_images_from_file(input_queue)
        image = preprocess_image(image)

        image_file = tf.train.batch([image], batch_size=5)

        sess = tf.Session()
        result = sess.run(image_file)

        return result

def load_labels(self, label_file):
    label = []
    proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines()
    for l in proto_as_ascii_lines:
        label.append(l.rstrip())
    return label

def main(self, image_file):
    self.model_file = "tf_files/retrained_graph.pb"
    self.label_file = "tf_files/retrained_labels.txt"
    input_height = 299
    input_width = 299
    input_mean = 128
    input_std = 128
    input_layer = "Mul"
    output_layer = "final_result"

    graph = self.load_graph()

    image_file = [f for f in listdir('test_images') if isfile(join('test_images', f))]

    for image in image_file:
        t = self.read_images_from_file(image, input_height = input_height,
                                           input_width = input_width,
                                            input_mean = input_mean,
                                             input_std = input_std)

        input_name = "import/" + input_layer
        output_name = "import/" + output_layer
        input_operation = graph.get_operation_by_name(input_name);
        output_operation = graph.get_operation_by_name(output_name);
        config = tf.ConfigProto(device_count={"CPU": 4},
                                inter_op_parallelism_threads=1,
                                intra_op_parallelism_threads=4)
        self.sess = tf.Session(graph=graph, config=config)
        start = time.time()
        results = self.sess.run(output_operation.outputs[0],
                          {input_operation.outputs[0]: t})
        end=time.time()
        results = np.squeeze(results)

        top_k = results.argsort()[-5:][::-1]
        labels = load_labels(label_file)

        print('\nEvaluation time (1-image): {:.3f}s\n'.format(end-start))

        for i in top_k:
            print(image_file, labels[i], results[i])

        return [image_file] + list(results)

    res_list = [f for f in listdir('test_images') if isfile(join('test_images', f))]

    for image in res_list:
        if image.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
            res_list.append(join('test_images', image))


if __name__ == '__main__':
         model_file = "tf_files/retrained_graph.pb"
         label_file = "tf_files/retrained_labels.txt"
         image_file = [f for f in listdir('test_images') if 
         isfile(join('test_images', f))]
         image_recognition_algorithm_obj = 
         image_recognition_algorithm(model_file, label_file, image_file)
         image_recognition_algorithm_obj.main(image_file)

0 个答案:

没有答案