我花了一整天时间,并且不知道我做错了什么。请帮忙。我使用以下代码创建了包含一些图像的TFRecords文件:
def convert_to_TF(images, labels, name):
label_count = labels.shape[0]
print('There are %d images in this dataset.' % (label_count))
if images.shape[0] != label_count:
raise ValueError('WTF! Devil! There are %d images and %d labels. Go fix yourself!' %
(images.shape[0], label_count))
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]
filename = os.path.join(name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(label_count):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
然后我尝试用以下内容读取保存的TFRecords文件:
def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue, 'train')
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'depth': tf.FixedLenFeature([], tf.int64)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
return image, label, height, width, depth
然后我得到以下错误。
Traceback (most recent call last):
File "/media/mcamp/Local SSHD/Python Projects/Garage Door Project/FreshStart/TFCode2.py", line 50, in <module>
label, image = read_and_decode(filename)
File "/media/mcamp/Local SSHD/Python Projects/Garage Door Project/FreshStart/TFCode2.py", line 31, in read_and_decode
_, serialized_example = reader.read(filename_queue, 'train')
File "/home/mcamp/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/io_ops.py", line 264, in read
queue_ref = queue.queue_ref
AttributeError: 'str' object has no attribute 'queue_ref'
答案 0 :(得分:6)
filename = "garage_door100_TRAIN.tfrecords"
filename_queue = tf.train.string_input_producer(
[filename], num_epochs=1)
label, image = read_and_decode(filename_queue)
这就是我所遗漏的......