我使用了Dataset API来消耗tfrecord数据。这是代码:
def read_tfrecord(fn, batch_size, epoch):
    """Build a one-shot iterator over a TFRecord file and return one batch.

    Args:
        fn: Path (or list of paths) to the .tfrecord file(s).
        batch_size: Number of examples per batch.
        epoch: Number of passes over the data.

    Returns:
        Tuple of tensors (paras, labels, specs) for the next batch, shaped
        (batch_size, 5), (batch_size, 1) and (batch_size, 603) respectively.
    """
    dataset = tf.data.TFRecordDataset(fn)

    def parser(record):
        # Fixed-length float32 features; the lengths must match exactly
        # what was serialized into the TFRecord file.
        features = {
            'para': tf.FixedLenFeature([5], tf.float32),
            'label': tf.FixedLenFeature([1], tf.float32),
            'spec': tf.FixedLenFeature([603], tf.float32)
        }
        parsed = tf.parse_single_example(record, features)
        return parsed['para'], parsed['label'], parsed['spec']

    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=100000)
    # BUGFIX: repeat() must come BEFORE batch(), and the final partial batch
    # must be dropped. The original order (batch then repeat) produced a
    # short batch of 41 examples at every epoch boundary; feeding that into
    # a placeholder declared with a fixed shape (64, 603) raises ValueError,
    # which is not caught by the OutOfRangeError handler in the training
    # loop — so training crashed at the end of epoch 1 and the data never
    # repeated 5 times.
    dataset = dataset.repeat(epoch)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    iterator = dataset.make_one_shot_iterator()
    paras, labels, specs = iterator.get_next()
    return paras, labels, specs
# Build the input pipeline: batches of 64 examples, repeated for 5 epochs.
train_dataset = tfrecord.read_tfrecord('./train_data.tfrecord', 64, 5)

with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(init)
    train_l_ = []  # training loss recorded at every iteration
    it_ = []       # matching iteration indices
    it = 0
    while True:
        try:
            # Pull the next batch; raises OutOfRangeError once the dataset
            # (all epochs) is exhausted. Keep the try body minimal so that
            # errors raised by the training step below are NOT silently
            # treated as end-of-data.
            para_batch, label_batch, spec_batch = sess.run(train_dataset)
        except tf.errors.OutOfRangeError:
            break
        it += 1
        # NOTE(review): strupata / r_spec are placeholders defined elsewhere.
        # If they are declared with a fixed batch dimension, e.g.
        # tf.placeholder(tf.float32, (64, 603)), then any short batch raises
        # "ValueError: Cannot feed value of shape (41, 603) ...". Declare
        # them with shape (None, 603) / (None, 5), or drop the remainder in
        # the input pipeline, so the batch dimension is flexible.
        _, train_l = sess.run([fw_op, train_loss],
                              feed_dict={strupata: para_batch, r_spec: spec_batch})
        train_l_.append(train_l)
        it_.append(it)
        if it % 10 == 0:
            print(' Iteration_{}, train_loss: {}'.format(it, train_l))
但是,最后有41个数据点无法读取,数据集也没有按预期重复5次。报错为:ValueError: Cannot feed value of shape (41, 603) for Tensor 'Placeholder:0', which has shape '(64, 603)'