在测试时替换输入管道(没有占位符的tf.contrib.data)

时间:2017-07-25 21:12:33

标签: tensorflow

我在培训期间使用tf.contrib.data函数作为输入管道(没有占位符)。我的问题是如何在测试时重复使用经过训练的模型并输入新数据?问题类似于this one,除了我不想在测试中使用占位符 - 我的测试数据集可能非常大,并且也应该避免占位符的减速。

有没有办法在测试时用新的输入管道替换输入管道?

1 个答案:

答案 0 :(得分:0)

我不确定是否有最佳方法可以解决这个问题,但这就是我解决它的方法:

在我的模型中,我使用的是简单的MLP,因此我的model()函数中包含这样的行:

train_layer = tf.add(tf.matmul(x_train, weights['w1']), biases['b1'])
train_layer = tf.nn.relu(train_layer)
test_layer = tf.add(tf.matmul(x_test, weights['w1']), biases['b1'])
test_layer = tf.nn.relu(test_layer)

如您所见,我有两个输入x_trainx_test。这些是从tf.contrib.data数据集迭代器获取批量数据的句柄:

x_train, x_train_labels = train_iter.get_next()
x_test, x_test_labels = test_iter.get_next()

所以我基本上在同一个图中有两个数据流,执行完全相同的操作。我还有两个模型输出mlp_trainmlp_test,具体取决于模型是使用x_train还是x_test输入进行评估。

现在:如果您使用mlp_train输出创建优化器,并使用mlp_test输出创建测试指标,则只需运行:sess.run(optimiser)即可在训练数据集,sess.run(test_metrics)在测试数据集上测试您的系统,您永远不需要使用feed_dict。

编辑:我读到了关于使用&#34;在模型培训时无法获得的数据的评论&#34;,我认为这个答案不能满足要求。< / p>