TensorFlow:在不同输出形状的数据集之间交替

时间:2018-08-24 04:18:53

标签: python tensorflow tensorflow-datasets

我正在尝试将tf.Dataset用于3D图像CNN,其中从训练集和验证集馈入其中的3D图像的形状不同(训练:(64,64,64),验证:(176,176,160))。我什至不知道这样做是可行的,但是我正在基于一篇论文重新创建该网络,并使用经典的feed_dict方法使该网络确实有效。出于性能原因(并且只是为了学习),我试图将网络切换为使用tf.Dataset

我有两个像下面这样构建的数据集和迭代器:

def _data_parser(dataset, shape):
        features = {"input": tf.FixedLenFeature((), tf.string),
                    "label": tf.FixedLenFeature((), tf.string)}
        parsed_features = tf.parse_single_example(dataset, features)

        image = tf.decode_raw(parsed_features["input"], tf.float32)
        image = tf.reshape(image, shape + (1,))

        label = tf.decode_raw(parsed_features["label"], tf.float32)
        label = tf.reshape(label, shape + (1,))
        return image, label

train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size) # batch_size = 16
train_iterator = train_dataset.make_initializable_iterator()

val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()

TensorFlow documentation提供了有关使用reinitializable_iteratorfeedable_iterator在数据集之间进行切换的示例,但它们都在相同输出形状的迭代器之间进行切换,但实际并非如此在这里。

那么我该如何使用tf.Datasettf.data.Iterator在训练集和验证集之间切换?

1 个答案:

答案 0 :(得分:2)

仅在尺寸不匹配的轴上为形状提供未指定的(None)值。例如

import numpy as np
import tensorflow as tf

training_dataset = tf.data.Dataset.from_tensors(np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
validation_dataset = tf.data.Dataset.from_tensors(np.zeros((176, 176, 160), np.float32)).repeat().batch(1)

iterator = tf.data.Iterator.from_structure(
    training_dataset.output_types,
    tf.TensorShape([None, None, None, None]))
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

sess = tf.InteractiveSession()
sess.run(training_init_op)
print(sess.run(next_element).shape)
sess.run(validation_init_op)
print(sess.run(next_element).shape)