Dataset API: Cannot feed value of shape (41, 603) for Tensor 'Placeholder:0', which has shape '(64, 603)'

Time: 2018-12-20 14:13:50

Tags: python tensorflow deep-learning

I am using the Dataset API to consume TFRecord data. Here is the code:

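import tensorflow as tf


def read_tfrecord(fn, batch_size, epoch):
    """Return (paras, labels, specs) batch tensors read from the TFRecord file fn."""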
    dataset = tf.data.TFRecordDataset(fn)

    def parser(record):
        # Each example stores 5 parameters, 1 label and a 603-value spectrum.
        features = {
            'para': tf.FixedLenFeature([5], tf.float32),
            'label': tf.FixedLenFeature([1], tf.float32),
            'spec': tf.FixedLenFeature([603], tf.float32)
        }
        parsed = tf.parse_single_example(record, features)
        para = parsed['para']
        label = parsed['label']
        spec = parsed['spec']

        return para, label, spec

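    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size=100000)
    # batch() keeps a final, smaller batch when the record count is not a
    # multiple of batch_size; because repeat() comes after batch(), every
    # epoch ends with such a partial batch.
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(epoch)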
    iterator = dataset.make_one_shot_iterator()
    paras, labels, specs = iterator.get_next()

    return paras, labels, specs

# batch_size=64, epoch=5; the result is the (paras, labels, specs) tuple of next-batch tensors.
train_dataset = tfrecord.read_tfrecord('./train_data.tfrecord', 64, 5)


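with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    # init, fw_op, train_loss and the placeholders strupata and r_spec are
    # defined in the model code, which is not shown here.
    sess.run(init)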
    train_l_ = []
    it_ = []
    it = 0
    while True:
        try:
            # Evaluate the iterator tensors to get NumPy batches, then feed them to the placeholders.
            para_batch, label_batch, spec_batch = sess.run(train_dataset)
            it += 1
            _, train_l = sess.run([fw_op, train_loss], feed_dict={strupata: para_batch, r_spec: spec_batch})
            train_l_.append(train_l)
            it_.append(it)

            if it % 10 == 0:
                print(' Iteration_{}, train_loss: {}'.format(it, train_l))
        except tf.errors.OutOfRangeError:
            break
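Where the 41-row batch comes from can be modeled in plain Python. The sketch below mimics the `batch` and `repeat` steps of `read_tfrecord` (the shuffle is left out because it does not change batch sizes). It is an illustration, not tf.data itself, and the record count 15 * 64 + 41 is a hypothetical value chosen to leave 41 records over, matching the error. The `drop_remainder` flag mirrors the argument of the same name that `Dataset.batch` accepts in TensorFlow 1.10 and later.

```python
from typing import Iterator, List, Sequence


def batch(items: Sequence[int], batch_size: int,
          drop_remainder: bool = False) -> Iterator[List[int]]:
    """Yield consecutive chunks of batch_size items, like Dataset.batch.

    With drop_remainder=False (the tf.data default) the last chunk may be short.
    """
    for start in range(0, len(items), batch_size):
        chunk = list(items[start:start + batch_size])
        if drop_remainder and len(chunk) < batch_size:
            break
        yield chunk


def batch_then_repeat(items: Sequence[int], batch_size: int, epochs: int,
                      drop_remainder: bool = False) -> Iterator[List[int]]:
    """Like dataset.batch(batch_size).repeat(epochs): each epoch is batched separately."""
    for _ in range(epochs):
        yield from batch(items, batch_size, drop_remainder)


def repeat_then_batch(items: Sequence[int], batch_size: int, epochs: int,
                      drop_remainder: bool = False) -> Iterator[List[int]]:
    """Like dataset.repeat(epochs).batch(batch_size): epochs are joined before batching."""
    return batch(list(items) * epochs, batch_size, drop_remainder)


# Hypothetical record count: 15 full batches of 64 plus 41 leftover records.
records = list(range(15 * 64 + 41))

sizes_a = [len(chunk) for chunk in batch_then_repeat(records, 64, 5)]
sizes_b = [len(chunk) for chunk in repeat_then_batch(records, 64, 5)]
sizes_c = [len(chunk) for chunk in batch_then_repeat(records, 64, 5, drop_remainder=True)]

print(sizes_a.count(41), len(sizes_a))  # 5 80: a 41-row batch ends every epoch
print(sizes_b[-1], len(sizes_b))        # 13 79: only the final batch is short
print(set(sizes_c), len(sizes_c))       # {64} 75: every batch has 64 rows
```

With `batch` before `repeat`, as in the code above, every epoch ends in a short batch; putting `repeat` first leaves a single short batch at the very end, and `drop_remainder=True` removes short batches entirely at the cost of skipping the leftover records in each epoch.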

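However, 41 data points cannot be read, and the dataset is not repeated 5 times either. The run fails with:

ValueError: Cannot feed value of shape (41, 603) for Tensor 'Placeholder:0', which has shape '(64, 603)'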
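Both symptoms plausibly share one cause. The placeholder fed with `spec_batch` has its batch dimension fixed at 64 (that shape is taken from the error message, since the placeholder definitions are not shown), and the `ValueError` raised when a 41-row batch arrives is not a `tf.errors.OutOfRangeError`, so the `except` clause does not catch it and training stops at the end of the first epoch. The pure-Python sketch below models this under stated assumptions: `is_compatible` is a simplified stand-in for TensorFlow's shape check, `StopIteration` stands in for `OutOfRangeError`, and the record count is a hypothetical value that leaves 41 rows over. A declared shape of `(None, 603)` corresponds to a placeholder created as `tf.placeholder(tf.float32, shape=[None, 603])`.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def is_compatible(fed: Sequence[int], declared: Shape) -> bool:
    """Simplified shape check: None in `declared` accepts any size in that dimension."""
    return len(fed) == len(declared) and all(
        d is None or d == f for f, d in zip(fed, declared)
    )


def steps_before_exit(batch_sizes: List[int], declared: Shape) -> Tuple[int, Optional[str]]:
    """Mimic the question's while/try loop and report how far it gets.

    Returns the number of completed training steps and the error message,
    or None if the loop ended normally.
    """
    stream = iter(batch_sizes)
    steps = 0
    try:
        while True:
            try:
                fed = (next(stream), 603)
                if not is_compatible(fed, declared):
                    # Like the feed_dict ValueError: the except below does not catch it.
                    raise ValueError(f"Cannot feed value of shape {fed}")
                steps += 1
            except StopIteration:  # stands in for tf.errors.OutOfRangeError
                break
    except ValueError as err:
        return steps, str(err)
    return steps, None


# Hypothetical record count: 15 full batches of 64 plus 41 leftover records.
full, leftover = divmod(15 * 64 + 41, 64)
sizes = ([64] * full + [leftover]) * 5  # what batch(64).repeat(5) yields

print(steps_before_exit(sizes, (64, 603)))    # (15, 'Cannot feed value of shape (41, 603)')
print(steps_before_exit(sizes, (None, 603)))  # (80, None)
```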

0 answers