我尝试写入一个 TFRecords 文件并读取它。我已经生成了 'train.tfrecords' 文件,但当我使用下面这个函数读取它时出现了错误:
image_size = 224  # side length (pixels) of the square crops returned to the model
# Side length of the square RGB images stored in train.tfrecords: the reported
# error shows each record decodes to 154587 = 227 * 227 * 3 values.
raw_image_size = 227


def read_and_decode(filename, batch_size, num_classes=3,
                    raw_size=None, crop_size=None):
    """Build a shuffled, augmented input pipeline over a TFRecords file.

    Each record must contain an int64 ``label`` feature and an ``img_raw``
    byte-string feature holding a ``raw_size x raw_size x 3`` uint8 image.
    Images are reshaped to their stored size, randomly cropped to
    ``crop_size x crop_size``, randomly flipped left/right and standardized;
    labels are one-hot encoded.

    Args:
        filename: path to the TFRecords file.
        batch_size: number of examples per batch.
        num_classes: width of the one-hot label vectors (default 3).
        raw_size: side length of the stored images; defaults to the
            module-level ``raw_image_size`` (looked up at call time).
        crop_size: side length of the returned crops; defaults to the
            module-level ``image_size`` (looked up at call time).

    Returns:
        ``(img_batch, label_batch)``: ``img_batch`` is a float32 tensor of
        shape ``[batch_size, crop_size, crop_size, 3]`` and ``label_batch`` is
        a float32 tensor of shape ``[batch_size, num_classes]``.

    Raises:
        ValueError: if ``crop_size`` is larger than ``raw_size``.
    """
    if raw_size is None:
        raw_size = raw_image_size
    if crop_size is None:
        crop_size = image_size
    if crop_size > raw_size:
        raise ValueError(
            "crop_size (%d) must not exceed raw_size (%d)" % (crop_size, raw_size))

    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # read() yields a (key, value) pair; only the serialized record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        "label": tf.FixedLenFeature([], tf.int64),
        "img_raw": tf.FixedLenFeature([], tf.string),
    })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.cast(img, tf.float32)
    # Reshape to the size the images were *stored* at. Reshaping directly to
    # the crop size fails ("Input to reshape is a tensor with 154587 values,
    # but the requested shape has 150528") whenever the two sizes differ.
    img = tf.reshape(img, [raw_size, raw_size, 3])
    # Augmentation: take a random crop_size window, then a random horizontal flip.
    img = tf.random_crop(img, [crop_size, crop_size, 3])
    img = tf.image.random_flip_left_right(img)
    img = tf.image.per_image_standardization(img)
    label = tf.cast(features['label'], tf.int32)

    img_batch, label_batch = tf.train.shuffle_batch(
        [img, label], batch_size=batch_size, num_threads=10,
        capacity=16 * batch_size, min_after_dequeue=8 * batch_size)

    # One-hot encode: 1.0 at (row i, column label_i), 0.0 everywhere else.
    label_batch = tf.reshape(label_batch, [batch_size, 1])
    indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
    label_batch = tf.sparse_to_dense(
        tf.concat(values=[indices, label_batch], axis=1),
        [batch_size, num_classes], 1.0, 0.0)

    # Static-shape sanity checks on the graph being built.
    assert len(img_batch.get_shape()) == 4
    assert img_batch.get_shape()[0] == batch_size
    assert img_batch.get_shape()[-1] == 3
    assert len(label_batch.get_shape()) == 2
    assert label_batch.get_shape()[0] == batch_size
    assert label_batch.get_shape()[1] == num_classes

    # Display the training images in the visualizer.
    tf.summary.image('images', img_batch)
    return img_batch, label_batch
错误是:
Caused by op 'Reshape', defined at:
File "C:/Users/Administrator/Desktop/tensorflow/ResNet/main.py", line 176, in <module>
tf.app.run()
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\platform\app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "C:/Users/Administrator/Desktop/tensorflow/ResNet/main.py", line 169, in main
train(hps)
File "C:/Users/Administrator/Desktop/tensorflow/ResNet/main.py", line 31, in train
images, labels = read_and_decode('train.tfrecords', hps.batch_size)
File "C:\Users\Administrator\Desktop\tensorflow\ResNet\input.py", line 15, in read_and_decode
img=tf.reshape(img,[image_size,image_size,3])
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 2510, in reshape
name=name)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 768, in apply_op
op_def=op_def)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\framework\ops.py", line 2336, in create_op
original_op=self._default_original_op, op_def=op_def)
File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\framework\ops.py", line 1228, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 154587 values, but the requested shape has 150528
[[Node: Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](Cast, Reshape/shape)]]
我在网上搜索了一段时间,我认为问题出在下面这几行代码上,但不知道如何修复,请帮帮我,谢谢。
# Excerpt of the failing lines from read_and_decode: each decoded record holds
# 154587 values (227 x 227 x 3, per the error message), so reshaping it to
# [image_size, image_size, 3] = [224, 224, 3] (150528 values) fails.
img=tf.decode_raw(features['img_raw'],tf.uint8)
img=tf.cast(img,tf.float32)
img=tf.reshape(img,[image_size,image_size,3])
答案 0 :(得分:0)
正如 @Mathias Rav 在评论中所解释的那样,您可以在错误日志中清楚地看到问题所在。
paths: {
// paths serve as alias
'npm:': 'https://unpkg.com/'
},
TFRecords 中每张图像包含 154587 个值,即 227×227×3;
但是,您在 tf.reshape 中请求的形状只有 150528 个值,即 224×224×3。
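请先确认模型所需的输入尺寸,然后相应地调整图像。或者,先按存储时的尺寸将图像 reshape 为 [227, 227, 3],再用 tf.random_crop 裁剪(或用 tf.image.resize_images 缩放)到 [224, 224, 3]。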
我希望这会有所帮助。