我在TensorFlow中训练神经网络,目前将PNG作为输入。虽然这很好用,但我想转而使用HDF5格式输入以获得更大的灵活性,但我在修改代码方面遇到了问题。在这两种情况下,文本文件都包含要用作训练示例的所有PNG / H5文件的路径列表,每行一个示例。我的PNG当前代码的简化版本如下:
class ImageReader(object):
def __init__(self, image_list, coord):
'''
image_list: string array of training image paths.
coord: TensorFlow queue coordinator.
'''
self.image_list = image_list
self.coord = coord
self.images = tf.convert_to_tensor(self.image_list, dtype=tf.string)
self.queue = tf.train.slice_input_producer([self.images])
self.image = read_images_from_disk(self.queue)
def dequeue(self, batch_size):
image_batch = tf.train.batch([self.image], batch_size)
return image_batch
def read_images_from_disk(input_queue):
img_contents = tf.read_file(input_queue[0])
img = tf.image.decode_png(img_contents, channels=1, dtype=tf.uint16)
return img
我希望能够做的是替换read_images_from_disk
方法来处理HDF5文件,特别是这些内容:
def read_images_from_disk(input_queue):
h5_contents=h5py.File(input_queue[0],'r')
img=h5_contents['image']
return img
显然,这并不起作用,因为h5py想要一个字符串作为输入而不是Tensor,我得到错误:AttributeError: 'Tensor' object has no attribute 'encode'
我需要修改什么才能使其正常工作?