用tf.data API

时间:2018-04-10 20:24:46

标签: python tensorflow tensorflow-datasets

我有一个现有的TensorFlow模型,它使用tf.placeholder作为模型输入,并使用tf.Session()的feed_dict参数运行以输入数据。以前,整个数据集都被读入内存并以这种方式传递。

我想使用更大的数据集并利用tf.data API的性能改进。我已经从中定义了一个tf.data.TextLineDataset和一次性迭代器,但是我很难弄清楚如何将数据导入模型来训练它。

首先,我尝试将feed_dict定义为从占位符到iterator.get_next()的字典,但这给了我一个错误,说明feed的值不能是tf.Tensor对象。更多的挖掘让我明白这是因为iterator.get_next()返回的对象已经是图形的一部分了,不像你将它提供给feed_dict那样 - 而且我不应该尝试使用feed_dict无论如何出于性能原因。

所以现在我已经摆脱了输入tf.placeholder并将其替换为定义我的模型的类的构造函数的参数;在我的训练代码中构建模型时,我将iterator.get_next()的输出传递给该参数。这似乎有点笨拙,因为它打破了模型定义与数据集/培训程序之间的分离。我现在收到一个错误,说Tensor代表(我相信)我的模型的输入必须来自与iterator.get_next()的Tensor相同的图形。

我采用这种方法走在正确的轨道上,只是在设置图表和会话或其他类似方面做错了什么? (数据集和模型都在会话之外初始化,并且在我尝试创建之前发生错误。)

或者我完全偏离这个并且需要做一些不同的事情,比如使用Estimator API并在输入函数中定义所有内容?

以下是一些展示最小示例的代码:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])

2 个答案:

答案 0 :(得分:6)

我也需要一点时间。你正走在正确的轨道上。整个数据集定义只是图表的一部分。我通常将它创建为与我的Model类不同的类,并将数据集传递给Model类。我指定要在命令行上加载的数据集类,然后动态加载该类,从而模块化地解耦数据集和图形。

请注意,您可以(并且应该)为数据集中的所有张量命名,当您通过您需要的各种转换传递数据时,它确实有助于简化操作。

您可以编写简单的测试用例,从iterator.get_next()中提取样本并显示它们,您可以使用sess.run(next_element_tensor),而不是feed_dict,因为您已经正确指出。

一旦你了解它,你可能会开始喜欢数据集输入管道。它迫使你很好地模块化你的代码,并迫使它进入一个易于单元测试的结构。

请务必阅读开发者指南,其中有大量示例:

https://www.tensorflow.org/programmers_guide/datasets

我要注意的另一件事是使用此管道处理火车和测试数据集是多么容易。这一点非常重要,因为您经常对您未在测试数据集上执行的训练数据集执行数据扩充,from_string_handle允许您这样做,并在上面的指南中有清楚描述。

答案 1 :(得分:2)

我给出的原始代码中模型构造函数中的行tf.reset_default_graph()导致了它。删除修复它。