我在plt.imshow
遇到以下错误
TypeError: Image data cannot be converted to float
对于此代码:
import keras
import tensorflow as tf
import matplotlib.pyplot as plt
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
def preprocess(x):
x = tf.image.per_image_standardization(x)
return x
train_images = preprocess(train_images)
test_images = preprocess(test_images)
plt.figure()
plt.imshow(train_images[1])
plt.colorbar()
plt.grid(False)
plt.show()
有什么想法为什么会这样?谢谢!
答案 0 :(得分:1)
在您的脚本中,train_images
不包含实际数据,而只是占位符张量:
train_images[1]
<tf.Tensor 'strided_slice_2:0' shape=(28, 28) dtype=float32>
最简单的解决方案是在脚本顶部启用急切执行:
tf.enable_eager_execution()
这意味着在运行时,张量实际上将包含您要绘制的数据:
train_images[1]
<tf.Tensor: id=95, shape=(28, 28), dtype=float32, numpy=
array([[-0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 ,
-0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 ,
-0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 ,
-0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 ,
-0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 , -0.4250042 ,
-0.4250042 , -0.4250042 , -0.4250042 ], # etc
应该解决您的错误。您可以阅读有关在TF的website上渴望执行的更多信息。
或者,您也可以通过在会话中实际评估图像张量来绘制图表:
with tf.Session() as sess:
img = sess.run(train_images[1])
plt.figure()
plt.imshow(img)
plt.colorbar()
plt.grid(False)
plt.show()