我是TensorFlow的新手。我正在尝试从github- https://github.com/Curt-Park/handwritten_digit_recognition运行用于数字识别'wide_resnet_28_10'的预训练NN。当我尝试预测图像时,它说期望输入具有4D。这就是我尝试过的-
from tensorflow.keras.models import load_model
import tensorflow as tf
import cv2
import numpy
model = load_model(r'C:\Users\sesha\Desktop\python\Deep learning NN\handwritten_digit_recognition-master\models\WideResNet28_10.h5')
image = cv2.imread(r'C:\Users\sesha\Desktop\python\Deep learning NN\test_org01.png')
img = tf.convert_to_tensor(image)
predictions = model.predict([img])
print(np.argmax(predictions))
大多数教程都含糊不清,我确实尝试了np.reshape(1,X,X,-1)无效。
答案 0 :(得分:2)
对于4D输入,它需要成批的数据。您可以通过以下操作将其设为4D张量:
predictions = model.predict(tf.expand_dims(img, 0))
如果这不起作用,请尝试使用预报而不是预测。
也: 我认为您的图像阅读不正确。它可能会给你一个字节串的张量。
这应该有效
path = tf.constant(img_path)
image = tf.io.read_file(path)
image = tf.io.decode_image(image)
image = tf.image.resize(image, (X, Y)) # if necessary