使用Tensorflow的数据集时,我必须调用iter.get_next吗?

时间:2018-04-25 14:54:43

标签: tensorflow tensorflow-datasets

我已将代码从基于队列的系统转换为tensorflow的数据集。转换后,我发现精度下降,时间增加。我将此归因于我的错误实现,我目前正在尝试解决可能存在的问题。现在通过这次转换的反复试验,我根据我遇到的一些文章和例子做了一些假设,我只是想确保我当前的实现是正确的,我的假设也是如此。

以前我有大量的图像,我会将它们分批排队,然后用100张图像弹出队列,执行处理和汇总,然后继续。通过我认为可能导致瓶颈的队列加载到内存中,所以当我听说数据集API时,我觉得值得一看。所以我现在检索所有图像信息并将其传递给我的方法,然后我通过数据集批处理方法执行批处理。之前和之后如下所示。我已经读过,没有必要在数据集上调用iter.get_next,因为操作会自动调用它,但是我最后看到的准确性,我对这是否犹豫是否犹豫不决是真是假。目前你可以看到,我只是将iter.initializer作为一个op传递给sess.run和我的其他操作并传递feed_dict。任何见解都会有所帮助,因为我对此有点新鲜。谢谢!

使用队列时的上一个示例函数: (请注意,我会将图像排入blob对象并将该子集传递给此方法)

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0):
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        summary, acc = sess.run([self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

使用数据集API的当前样本函数: (现在在调用之前,我使用所有数据填充我的blob对象并使用下面的批处理功能 - 请注意我从未调用iter.get_next())

def get_summary(self, sess, images, labels, weights, keep_prob = 1.0, batch_size=32):
        dataset = tf.data.Dataset.from_tensor_slices((self._input_images, self._input_labels,
                                                      self._input_weights)).repeat().batch(batch_size)

        iter = dataset.make_initializable_iterator()
        feed_dict = {self._input_images: images, self._input_labels: labels,
                     self._input_weights: weights, self._is_training: False}
        _, summary, acc = sess.run([iter.initializer, self._summary_op, self._accuracy], feed_dict=feed_dict)

        return summary, acc

1 个答案:

答案 0 :(得分:1)

从该代码段开始,您似乎永远不会使用iter中的值,因此它应该对您的摘要没有影响。例如,您应该能够删除创建迭代器的行,并从传递给iter.initializer的列表中删除sess.run()并获得相同的结果。

要回答更广泛的问题"我必须在基于图表的TensorFlow中调用iter.get_next()?":tf.data.Iterator和张量/操作之间必须存在数据流连接您传递给sess.run()以便使用该迭代器中的值。如果您使用的是低级 TensorFlow API,最简单的方法是调用iter.get_next()获取一个或多个tf.Tensor个对象,然后将这些张量用作模型的输入。

但是,如果您使用的是高级 tf.estimator API,则input_fn可以返回tf.data.Dataset而无需创建tf.data.Iterator(或者调用Iterator.get_next(),Estimator API将负责创建迭代器并为您调用get_next()