I have the following code:
import tensorflow as tf
filenames = # list of N filenames of input data
labels = # NxM array where each row is one sample's label in one-hot encoding
noise = # list of N filenames of noise used for data augmentation
feature_placeholder = tf.placeholder(tf.string)
noise_placeholder = tf.placeholder(tf.string)
label_placeholder = tf.placeholder(labels.dtype, shape=(None, labels.shape[1]))
train_dataset = tf.data.Dataset.from_tensor_slices((feature_placeholder, label_placeholder, noise_placeholder))
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.map(parse_fn) # some function for data parsing and data augmentation
train_dataset = train_dataset.batch(100)
train_dataset = train_dataset.prefetch(buffer_size=1)
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
features, labels, _ = iterator.get_next() # don't care about the noise
train_iterator = iterator.make_initializer(train_dataset, name='train_iterator')
# other code to run the model
But when I run the program, I get the following error:
InvalidArgumentError (see above for traceback):
All components must have the same size in the 0th dimension
[[node TensorSliceDataset (defined at train.py:141) = TensorSliceDataset[Toutput_types=[DT_STRING, DT_DOUBLE, DT_STRING], _class=["loc:@IteratorV2"], output_shapes=[<unknown>, [15], <unknown>]
It occurs at train_iterator = iterator.make_initializer(train_dataset, name='train_iterator').
I suspect it comes from the way I define the placeholder shapes, but I can't figure out how to correct it. Does anyone know?
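For context, tf.data.Dataset.from_tensor_slices slices every component along dimension 0, so the values fed for feature_placeholder, label_placeholder, and noise_placeholder all need the same length N. Whether a length mismatch is the actual cause here is only an assumption. The sketch below is a plain-Python sanity check that could be run on the values before feeding them. The helper names (leading_sizes, check_same_leading_size) and the example data are hypothetical and not part of TensorFlow or of the code above.

```python
def leading_sizes(**components):
    """Return the size of the 0th dimension of each named component."""
    return {name: len(value) for name, value in components.items()}


def check_same_leading_size(**components):
    """Raise ValueError if the components disagree in their 0th dimension."""
    sizes = leading_sizes(**components)
    if len(set(sizes.values())) > 1:
        raise ValueError("0th-dimension sizes differ: %r" % (sizes,))
    return sizes


# Hypothetical stand-ins for the real data:
# 3 filenames, 3 one-hot label rows, but only 2 noise files.
example_filenames = ["a.wav", "b.wav", "c.wav"]
example_labels = [[1, 0], [0, 1], [1, 0]]
example_noise = ["n1.wav", "n2.wav"]

try:
    check_same_leading_size(
        filenames=example_filenames,
        labels=example_labels,
        noise=example_noise,
    )
except ValueError as err:
    # Reports which component has a different number of rows.
    print(err)
```

Because the helpers only use len(), they work on Python lists as well as NumPy arrays, so the same call could be made with the real filenames, labels, and noise just before running the initializer.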