我正在尝试将tf.Dataset
用于3D图像CNN,其中从训练集和验证集馈入其中的3D图像的形状不同(训练:(64,64,64),验证:(176,176,160))。我什至不知道这样做是可行的,但是我正在基于一篇论文重新创建该网络,并使用经典的feed_dict
方法使该网络确实有效。出于性能原因(并且只是为了学习),我试图将网络切换为使用tf.Dataset
。
我有两个像下面这样构建的数据集和迭代器:
def _data_parser(dataset, shape):
features = {"input": tf.FixedLenFeature((), tf.string),
"label": tf.FixedLenFeature((), tf.string)}
parsed_features = tf.parse_single_example(dataset, features)
image = tf.decode_raw(parsed_features["input"], tf.float32)
image = tf.reshape(image, shape + (1,))
label = tf.decode_raw(parsed_features["label"], tf.float32)
label = tf.reshape(label, shape + (1,))
return image, label
train_datasets = ["train.tfrecord"]
train_dataset = tf.data.TFRecordDataset(train_datasets)
train_dataset = train_dataset.map(lambda x: _data_parser(x, (64, 64, 64)))
train_dataset = train_dataset.batch(batch_size) # batch_size = 16
train_iterator = train_dataset.make_initializable_iterator()
val_datasets = ["validation.tfrecord"]
val_dataset = tf.data.TFRecordDataset(val_datasets)
val_dataset = val_dataset.map(lambda x: _data_parser(x, (176, 176, 160)))
val_dataset = val_dataset.batch(1)
val_iterator = val_dataset.make_initializable_iterator()
TensorFlow documentation提供了有关使用reinitializable_iterator
或feedable_iterator
在数据集之间进行切换的示例,但它们都在相同输出形状的迭代器之间进行切换,但实际并非如此在这里。
那么我该如何使用tf.Dataset
和tf.data.Iterator
在训练集和验证集之间切换?
答案 0 :(得分:2)
仅在尺寸不匹配的轴上为形状提供未指定的(None
)值。例如
import numpy as np
import tensorflow as tf
training_dataset = tf.data.Dataset.from_tensors(np.zeros((64, 64, 64), np.float32)).repeat().batch(4)
validation_dataset = tf.data.Dataset.from_tensors(np.zeros((176, 176, 160), np.float32)).repeat().batch(1)
iterator = tf.data.Iterator.from_structure(
training_dataset.output_types,
tf.TensorShape([None, None, None, None]))
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
sess = tf.InteractiveSession()
sess.run(training_init_op)
print(sess.run(next_element).shape)
sess.run(validation_init_op)
print(sess.run(next_element).shape)