在加载图像时,我试图通过在pyplot中将其打印出来,以确保正确加载了图像,但是我遇到了问题。如何将这些图像加载到Tensorflow中,并使用pyplot的imshow()
(或其他方式)进行检查?
图像数据是单通道(黑白)jpeg。最初将其加载为具有未知形状和uint8 dtype的Tensor。我已经尝试确保将Tensor重塑为正确的形状并将其铸造为float32。我还尝试确保将值从0.0-1.0缩放为浮点数,并使用imshow()
函数中的Gray映射。
import tensorflow as tf
import matplotlib.pyplot as plt
def load_and_preprocess_jpeg(imagepath):
img = tf.read_file(imagepath)
img_tensor = tf.image.decode_jpeg(img)
img_tensor.set_shape([792,1224,1])
img_tensor = tf.reshape(img_tensor, [792,1224])
img_tensor = tf.cast(img_tensor, tf.float32, name='ImageCast')
#img_tensor /= 255.0 #Tried with and without
return img_tensor
def read_data(all_filenames):
path_Dataset = tf.data.Dataset.from_tensor_slices(all_filenames)
image_Dataset = path_Dataset.map(load_and_preprocess_jpeg)
plt.figure(figsize=(8,8))
temp_DS = image_Dataset.take(4)
itera = temp_DS.make_one_shot_iterator()
for n in range(4):
image = itera.get_next()
plt.subplot(2,2,n+1)
plt.imshow(image)
plt.grid(False)
plt.xticks([])
plt.yticks([])
我的堆栈跟踪:
File "<stdin>", line 1, in <module>
line 34, in read_data
plt.imshow(image)
matplotlib\pyplot.py, line 3205, in imshow
**kwargs)
matplotlib\__init__.py, line 1855, in inner
return func(ax, *args, **kwargs)
matplotlib\axes\_axes.py, line 5487, in imshow
im.set_data(X)
matplotlib\image.py, line 649, in set_data
raise TypeError("Image data cannot be converted to float")
答案 0 :(得分:1)
您正在尝试绘制张量。为了绘制图像,您必须首先运行会话。尝试以下代码:
import tensorflow as tf
import matplotlib.pyplot as plt
def load_and_preprocess_jpeg(imagepath):
img = tf.read_file(imagepath)
img_tensor = tf.image.decode_jpeg(img)
img_tensor = tf.image.resize_images(img_tensor, [img_size,img_size])
img_tensor = tf.cast(img_tensor, tf.float32, name='ImageCast')
img_tensor /= 255.0
return img_tensor
path_Dataset = tf.data.Dataset.from_tensor_slices(all_filenames)
image_Dataset = path_Dataset.map(load_and_preprocess_jpeg)
temp_DS = image_Dataset.take(4)
itera = temp_DS.make_one_shot_iterator()
image = itera.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
while True:
try:
image_to_plot = sess.run(image)
plt.figure(figsize=(8,8))
plt.subplot(2,2,n+1)
plt.imshow(image_to_plot)
plt.grid(False)
plt.xticks([])
plt.yticks([])
except tf.errors.OutOfRangeError:
break