Tensorflow预测错误,invalidArgumentError:断言失败:[无法将字节解码为JPEG,PNG,GIF或BMP]

时间:2019-03-13 18:45:51

标签: python tensorflow base64 prediction

我使用Google对象检测Api训练了Tensorflow Ssd对象检测模型,并使用提供的“ export_inference_graph.py”脚本作为“ Saved_model.pb”文件(以“ encoded_image_string_tensor”作为输入类型)导出了训练后的模型。我试图对模型进行预测,但出现以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [Unable to decode bytes as JPEG, PNG, GIF, or BMP]

İ将模型加载到图形中,如下所示:

with tf.Session() as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], saved_model_file)
    graph = tf.get_default_graph()

做出如下预测:

# Convert the image into base64 encoded string
img = Image.open(IMAGE_PATH)    
resized_img = img.resize((300, 300), Image.ANTIALIAS)
binary_io = io.BytesIO()
resized_img.save(binary_io, "JPEG")

bytes_string_image = base64.b64encode(binary_io.getvalue()).decode("utf-8")

# Define the input and output placeholder tensors
input_tensor = graph.get_tensor_by_name('encoded_image_string_tensor:0')
tensor_dict = {}
for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes']:
        tensor_name = key + ':0'
        tensor_dict[key] = graph.get_tensor_by_name(tensor_name)

# Finally, do the prediciton
output_dict = sess.run(tensor_dict, feed_dict={
                           input_tensor: bytes_string_image})

2 个答案:

答案 0 :(得分:0)

似乎在创建TFRecords时,仅支持jpeg图像,而在文档中没有指出!同样,当您尝试使用其他类型时,它不会发出任何警告,也不会出现任何异常,因此像我这样的人会浪费大量的时间来调试一些很容易发现并首先解决的问题。 无论如何,将所有图像转换为jpg解决了这个怪异的难题。

您还可以检查此问题: https://github.com/tensorflow/tensorflow/issues/13044

此程序将挑选出实际上不是jpeg的文件。删除它们,然后就可以了。 “” 导入imghdr
导入cv2
导入操作系统
导入glob

对于['train','test']中的文件夹:
image_path = os.path.join(os.getcwd(),('images /'+文件夹))
打印(图像路径)
用于glob.glob(image_path +'/*.jpg')中的文件:
图片= cv2.imread(文件)
file_type = imghdr.what(文件)
如果file_type!='jpeg':
print(文件+“-无效-” + str(文件类型))

cv2.imwrite(文件,图像)

“”“

答案 1 :(得分:0)

我的问题是我将图像保存为字节,但是它必须是字符串。

所以代替这个:

encoded = image.tobytes()
features = {
    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded])),
    ...
}

您需要执行以下操作:

encoded = cv2.imencode('.jpg', image)[1].tostring()
features = {
    'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded])),
    ...
}