Feed Iterator to Tensorflow Graph

时间:2018-01-19 21:56:20

标签: python tensorflow tensorflow-datasets

我使用tf.data.Iterator创建了make_one_shot_iterator(),并希望用它来训练我的(现有)模型。

目前我的培训看起来像这样

input_node = tf.placeholder(tf.float32, shape=(None, height, width, channels))

net = models.ResNet50UpProj({'data': input_node}, batch_size, keep_prob=True,is_training=True)

labels = tf.placeholder(tf.float32, shape=(None, width, height, 1))
huberloss = tf.losses.huber_loss(predictions=net.get_output(),labels=labels)

然后致电

sess.run(train_op, feed_dict={labels:output_img, input_node:input_img})

经过训练,我可以得到这样的预测:

pred = sess.run(net.get_output(), feed_dict={input_node: img})

现在使用迭代器我尝试了类似的东西

next_element = iterator.get_next()

传递输入数据如下:

net = models.ResNet50UpProj({'data': next_element[0]}, batch_size, keep_prob=True,is_training=True)

像这样定义损失函数:

huberloss = tf.losses.huber_loss(predictions=net.get_output(),labels=next_element[1])

执行训练就像在每次调用时自动遍历迭代器一样简单:

sess.run(train_op)

我的问题是:训练后我无法做出任何预测。或者更确切地说,我不知道在我的情况下使用迭代器的正确方法。

1 个答案:

答案 0 :(得分:3)

解决方案1:仅为推理创建单独的子图,尤其是当您具有批量规范化和退出(is_training=False)等层时。

# The following code assumes that you create variables with `tf.get_variable`. 
# If you create variables manually, you have to reuse them manually.
with tf.variable_scope('somename'):
    net = models.ResNet50UpProj({'data': next_element[0]}, batch_size, keep_prob=True, is_training=True)
with tf.variable_scope('somename', reuse=True):
    net_for_eval = models.ResNet50UpProj({'data': some_placeholder_or_inference_data_iterator}, batch_size, keep_prob=True, is_training=False)

解决方案2:使用feed_dict。您可以使用Feed dict替换几乎所有tf.Tensor,而不只是tf.placeholder

sess.run(huber_loss, {next_element[0]: inference_image, next_element[1]: inference_labels})