我将first steps with tensor flow视为google machine learning crahs course的一部分并且已经混淆了。我的理解是(如果我错了,请纠正我):
my_input_fn
,它将数据格式化为相关的TensorFlow结构.Tensor train
电话。 my_input_fn
以获取连续批量数据以调整模型。 (现在对此立即怀疑) my_input_fn
在此处定义:
def my_input_fn(features, targets, batch_size=1, shuffle=True, num_epochs=None):
"""Trains a linear regression model of one feature.
Args:
features: pandas DataFrame of features
targets: pandas DataFrame of targets
batch_size: Size of batches to be passed to the model
shuffle: True or False. Whether to shuffle the data.
num_epochs: Number of epochs for which data should be repeated. None = repeat indefinitely
Returns:
Tuple of (features, labels) for next data batch
"""
# Convert pandas data into a dict of np arrays.
features = {key:np.array(value) for key,value in dict(features).items()}
# Construct a dataset, and configure batching/repeating
ds = Dataset.from_tensor_slices((features,targets)) # warning: 2GB limit
ds = ds.batch(batch_size).repeat(num_epochs)
# Shuffle the data, if specified
if shuffle:
ds = ds.shuffle(buffer_size=10000)
# Return the next batch of data
features, labels = ds.make_one_shot_iterator().get_next()
return features, labels
从我对my_input_fn
的阅读中,我不明白这是怎么发生的。我只有python的基本知识,但我对函数的读取是每次调用它都会重新初始化pandas帧中的张量结构,得到一个迭代器然后返回它的第一个元素。每次被召唤。当然,在这个例子的情况下,如果数据被洗牌(默认情况下是这样)并且数据集很大,那么不太可能你将获得100步的重复数据,但这种气味草率编程(即如果它没有改组,它总会返回相同的第一个训练数据集)所以我怀疑是这种情况。
我的下一个怀疑是one_shot_iterator().get_next()
电话正在做一些有趣/古怪/棘手的事情。就像返回某种后期eval结构一样,允许train
函数从自身枚举到下一批,而不是重新调用my_input_fn
?
但老实说,我想澄清这一点,因为在这个阶段,我想要考虑的时间比我想要的更多,我不再接近理解。
我的研究尝试导致了进一步的混乱。
教程建议阅读this - 在某一点上,它表示“每个估算器的训练,评估和预测方法都需要输入函数来返回包含张量流张量的(特征,标签)对。”好的,这与我最初的想法是一致的。基本上是TensorFlow结构中打包的示例和标签。
然后它显示了它返回的结果,它就是这样的东西(例子):
({
'SepalLength': <tf.Tensor 'IteratorGetNext:2' shape=(?,) dtype=float64>,
'PetalWidth': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=float64>,
'PetalLength': <tf.Tensor 'IteratorGetNext:0' shape=(?,) dtype=float64>,
'SepalWidth': <tf.Tensor 'IteratorGetNext:3' shape=(?,) dtype=float64>},
Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64))
在代码实验室中,my_input_fn(my_feature, targets)
返回:
({'total_rooms': <tf.Tensor 'IteratorGetNext:0' shape=(?,) dtype=float64>},
)
我没有想法如何做到这一点。我对tensors的解读没有提到这样的事情。我甚至不知道如何开始用我的基本Python和不存在的TensorFlow知识来查询它。
one shot iterator的文档说它创建了一个枚举元素的迭代器。再次,这符合我的想法。
get_next文档说:
返回包含下一个元素的tf.Tensors的嵌套结构。
我不知道如何解析这个问题。什么样的嵌套结构?我的意思是它看起来像一个元组,但为什么你不会说元组?是什么决定这个?在哪里描述?当然这很重要吗?
我在这里误解了什么?
(对于一个据称不需要先前TensorFlow知识的课程,谷歌机器学习速成课程让我感到非常愚蠢。我真的很好奇其他人在我的情况下是怎么回事。)
答案 0 :(得分:7)
输入函数(在本例中为my_input_function
)是不是重复调用。它被调用一次,创建一堆tensorflow操作(用于创建数据集,对其进行混洗等),最后返回迭代器的get_next
操作。这个op 将重复调用,但它所做的就是迭代数据集。你在my_input_function
中所做的事情(如洗牌,批量,重复)只发生过一次。
一般情况下:使用Tensorflow程序时,你必须习惯这样一个事实:它们的工作方式与#34; normal&#34; Python程序。您编写的大多数代码(特别是前面带tf.
的代码)只会执行一次以构建计算图,然后此图执行多次。
编辑:但是,实验tf.eager
API(据说可以完全集成在TF 1.7中)可以完全改变这一点,即在你编写它们时执行(更像是numpy)。这应该允许更快的实验。
逐步完成输入功能:首先使用从&#34;张量切片&#34;中创建的数据集开始。 (例如numpy数组)。然后调用batch
方法。这基本上创建一个新的数据集,其元素是原始数据集的元素批次。类似地,重复和混洗也会创建新的数据集(确切地说,它们创建操作,一旦它们作为计算图的一部分实际执行,将创建这些数据集)。最后,在批处理的,重复的,混洗的数据集上返回一个迭代器。只有这个迭代器的get_next
操作将重复执行,返回数据集的新元素,直到它耗尽为止。
编辑:确实iterator.get_next()
只返回操作。只有在tf.Session
中运行此操作后才会执行迭代。
至于你拥有的输出&#34;不知道该怎么做&#34;:不确定你的问题是什么,但你发布的只是将字符串映射到张量的字符串。张量自动获取与产生它们的操作相关的名称(iterator.get_next
),并且它们的形状未知,因为批量大小可以变化 - 即使指定它,如果批量大小不对,最后一批可能更小&# 39; t均匀划分数据集大小(例如,数据集包含10个元素,批处理大小为4 - 最后一批将为大小2)。张量形状中的?
个元素表示未知的维度
编辑:关于命名:操作接收默认名称。但是,在这种情况下,它们都会收到相同的默认名称(IteratorGetNext
),但不能有多个具有相同名称的操作。因此,Tensorflow会自动附加整数以使名称唯一。这就是全部!
至于#34;嵌套结构&#34;:输入函数通常与tf.estimator
一起使用,它需要一个相当简单的输入结构(包含Tensor或Tensors dict作为输入的元组,以及Tensor as输出,如果我没有弄错)。但是,通常,输入函数支持更复杂的嵌套输出结构,例如(a, (tuple, of), (tuples, (more, tuples, elements), and), words)
。请注意,这是一个输出的结构,即一个&#34;步骤&#34;迭代器(例如一批数据)。反复调用此操作将枚举整个数据集
编辑:输入函数返回的结构仅由该函数决定!例如。来自张量切片的数据集将返回元组,其中第n个元素是第n个&#34;张量切片&#34;。像dataset.zip
这样的函数就像Python等价物一样工作。如果您将采用具有结构(e1,e2)的数据集并使用数据集(e3,)将其压缩,您将得到((e1,e2),e3)。
需要什么格式取决于应用程序。原则上,您可以提供任何格式,然后接收此输入的代码可以对其执行任何操作。但是,正如我所说,最常见的用法可能是在tf.estimator
的上下文中,你的输入函数应该返回一个元组(特征,标签),其中的特征是张量的张量或词典(如在你的情况下)和标签也是张量的张量或词典。如果其中一个是dict,则模型函数负责从那里获取正确的值/张量。
一般来说,我会建议你玩这个东西。查看the tf.data API,当然还有the Programmer's Guide。创建一些数据集/输入函数,只需启动一个会话并重复运行iterator.get_next()
操作。看看那里出了什么。尝试所有不同的转换,例如zip
,take
,padded_batch
...在不需要对此数据进行任何实际操作的情况下查看它,可以让您更好地理解。