我尝试使用tf.Data加载图像,但出现错误。这是我的代码:
import cv2
import tensorflow as tf
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
image_decoded = tf.expand_dims(image_decoded, dim=0)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
image_resized = tf.expand_dims(image_resized, dim=0)
return image_resized, label
filenames = ["data/img.jpeg", "data/img.jpeg"]
labels = [0, 37]
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(filenames), tf.constant(labels)))
dataset = (dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype]))))
dataset = dataset.map(_resize_function)
dataset = dataset.batch(2)
dataset = dataset.prefetch(2)
# iterator = dataset.make_initializable_iterator()
configProt = tf.ConfigProto()
configProt.gpu_options.allow_growth = True
configProt.allow_soft_placement = True
sess = tf.Session(config = configProt)
iterator = dataset.make_one_shot_iterator()
# next_element = iterator.get_next()
images, labels = iterator.get_next()
print(sess.run(labels))
但是,我得到的是
tensorflow.python.framework.errors_impl.UnimplementedError: Unsupported object type Tensor
[[Node: PyFunc = PyFunc[Tin=[DT_STRING, DT_INT32], Tout=[DT_UINT8, DT_INT32], token="pyfunc_0"](arg0, arg1)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,1,28,28,?], <unknown>], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
我无法使用tf1.8运行它。有什么问题?
答案 0 :(得分:0)
我通过删除_read_py_function中的tf.xxx解决了该问题。