推断使用Tensorflow Dataset API训练的模型的新输入

时间:2018-05-11 16:05:20

标签: python tensorflow

我从数据集API训练一个tensorflow(1.7)模型,如下所示:

features_data_ph = tf.placeholder(tf.int32, [None, None, max_sent_len], 'features_data_ph')

mode_ph = tf.placeholder(tf.int32, name='mode_ph')

labels_data_ph = tf.placeholder(tf.int32, [None, num_classes], 'labels_data_ph')

train_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
train_dataset = train_dataset.shuffle(buffer_size=100000).batch(batch_size)
train_iterator = train_dataset.make_initializable_iterator()

val_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
val_iterator = val_dataset.make_initializable_iterator()

input_tensor, labels_tensor = tf.case(
                {
                    tf.equal(mode_ph, 0): train_iter.get_next,
                    tf.equal(mode_ph, 1): val_iter.get_next,
                }
            )

logits = model(input_tensor)
loss = get_loss(logits, labels_tensor)
...
# start of training epoch
session.run(train_iterator.initializer, feed_dict={
    features_data_ph: train_features,
    labels_data_ph: train_labels
})
...
# new validation after some steps
session.run(val_iterator.initializer, feed_dict={
    features_data_ph: val_features,
    labels_data_ph: val_labels
})

现在您可以看到,input_tensor取决于数据集。所以我不能只提供一个新的numpy数组来推断数据集中没有的数据。

到目前为止我所做的是创建一个第三个数据集,用于保存推理数据(并将tf.equal(mode_ph, 2): infer_iter.get_next添加到tf.case

有没有更好的方法来推断现有数据集中没有的数据?使用val_dataset会覆盖它包含的数据

0 个答案:

没有答案