I trained a TensorFlow (1.7) model with the Dataset API, as follows:
features_data_ph = tf.placeholder(tf.int32, [None, None, max_sent_len], 'features_data_ph')
mode_ph = tf.placeholder(tf.int32, name='mode_ph')
labels_data_ph = tf.placeholder(tf.int32, [None, num_classes], 'labels_data_ph')
train_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
train_dataset = train_dataset.shuffle(buffer_size=100000).batch(batch_size)
train_iterator = train_dataset.make_initializable_iterator()
val_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
val_iterator = val_dataset.make_initializable_iterator()
input_tensor, labels_tensor = tf.case(
    {
        tf.equal(mode_ph, 0): train_iterator.get_next,
        tf.equal(mode_ph, 1): val_iterator.get_next,
    }
)
logits = model(input_tensor)
loss = get_loss(logits, labels_tensor)
...
# start of training epoch
session.run(train_iterator.initializer, feed_dict={
features_data_ph: train_features,
labels_data_ph: train_labels
})
...
# re-initialize the validation iterator after some training steps
session.run(val_iterator.initializer, feed_dict={
features_data_ph: val_features,
labels_data_ph: val_labels
})
As you can see, input_tensor depends on the dataset, so I cannot simply feed a new numpy array to run inference on data that is not already in a dataset.
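To make the problem concrete, this is roughly the call I would like to make at inference time (new_features is hypothetical data that is not in any dataset), and why it does not work with the graph above:

# desired, but not possible with the graph above:
# predictions = session.run(logits, feed_dict={features_data_ph: new_features, mode_ph: 0})
# Feeding features_data_ph here has no effect: the iterators only read the placeholder
# when their initializer op runs, and logits takes its input from get_next(), so this
# would just return predictions for whatever batch the selected iterator yields next.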
What I have done so far is to create a third dataset that holds the inference data (and add tf.equal(mode_ph, 2): infer_iter.get_next to the tf.case), as sketched below.
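Concretely, the workaround looks roughly like this (infer_features and infer_labels are hypothetical inference inputs; infer_labels are dummy values only needed to satisfy the labels placeholder):

infer_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
infer_iter = infer_dataset.make_initializable_iterator()

input_tensor, labels_tensor = tf.case(
    {
        tf.equal(mode_ph, 0): train_iterator.get_next,
        tf.equal(mode_ph, 1): val_iterator.get_next,
        tf.equal(mode_ph, 2): infer_iter.get_next,
    }
)
...
# at inference time: initialize the extra iterator, then select it via mode_ph
session.run(infer_iter.initializer, feed_dict={
    features_data_ph: infer_features,
    labels_data_ph: infer_labels  # dummy labels, only needed to fill the placeholder
})
predictions = session.run(logits, feed_dict={mode_ph: 2})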
Is there a better way to run inference on data that is not already in one of the datasets? Reusing val_dataset would overwrite the data it contains.
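For clarity, by "overwrite" I mean the following: if I reuse the validation iterator for inference (infer_features/infer_labels again hypothetical),

session.run(val_iterator.initializer, feed_dict={
    features_data_ph: infer_features,
    labels_data_ph: infer_labels
})

then the next validation pass would read the inference data, so I would have to re-initialize val_iterator with val_features/val_labels every time before validating again.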