Q1:我正在关注Recurrent Neural Networks上的this tutorial,我想知道你为什么需要在代码的以下部分创建feed_dict
:
def run_epoch(session, model, eval_op=None, verbose=False):
state = session.run(model.initial_state)
fetches = {
"cost": model.cost,
"final_state": model.final_state,
}
if eval_op is not None:
fetches["eval_op"] = eval_op
for step in range(model.input.epoch_size):
feed_dict = {}
for i, (c, h) in enumerate(model.initial_state):
feed_dict[c] = state[i].c
feed_dict[h] = state[i].h
vals = session.run(fetches, feed_dict)
我测试了,似乎如果删除这部分代码,代码也会运行:
def run_epoch(session, model, eval_op=None, verbose=False):
fetches = {
"cost": model.cost,
"final_state": model.final_state,
}
if eval_op is not None:
fetches["eval_op"] = eval_op
for step in range(model.input.epoch_size):
vals = session.run(fetches)
所以我的问题是,为什么在提供新批量数据后需要将初始状态重置为零?
Q2:此外,据我所知,使用feed_dict
被认为是缓慢的。这就是为什么建议使用tf.data
API提供数据的原因。在这种情况下使用feed_dict
也是一个问题吗?如果是这样,如何避免在此示例中使用feed_dict
。
UPD:非常感谢@jdehesa的详细回复。它帮助很大!在我结束这个问题并接受你的回答之前,你能澄清一下你提到回答Q1的一点。
我现在看到了feed_dict
的目的。但是,我不确定这是教程中实现的内容。从你说的话:
在每个纪元的开头,代码首先采用默认的"零状态"然后进入一个循环,其中当前状态为初始状态,运行模型并将输出状态设置为下一次迭代的新当前状态。
我刚看了一下本教程的the source code,我没有看到输出状态被设置为下一次迭代的新当前状态。它是隐含地在某处完成的还是我错过了什么?
我也许在理论方面也缺少一些东西。为了确保我理解正确,这里有一个简单的例子。假设输入数据是一个存储0到120整数值的数组。我们将批量大小设置为5
,一批中的数据点数为24
,以及时间步长数展开的RNN为10
。在这种情况下,您只能在0
到20
的时间点使用数据点。然后分两步处理数据(model.input.epoch_size = 2
)。当您遍历model.input.epoch_size
:
state = session.run(model.initial_state)
# ...
for step in range(model.input.epoch_size):
feed_dict = {}
for i, (c, h) in enumerate(model.initial_state):
feed_dict[c] = state[i].c
feed_dict[h] = state[i].h
vals = session.run(fetches, feed_dict)
您提供了一批这样的数据:
> Iteration (step) 1:
x:
[[ 0 1 2 3 4 5 6 7 8 9]
[ 24 25 26 27 28 29 30 31 32 33]
[ 48 49 50 51 52 53 54 55 56 57]
[ 72 73 74 75 76 77 78 79 80 81]
[ 96 97 98 99 100 101 102 103 104 105]]
y:
[[ 1 2 3 4 5 6 7 8 9 10]
[ 25 26 27 28 29 30 31 32 33 34]
[ 49 50 51 52 53 54 55 56 57 58]
[ 73 74 75 76 77 78 79 80 81 82]
[ 97 98 99 100 101 102 103 104 105 106]]
> Iteration (step) 2:
x:
[[ 10 11 12 13 14 15 16 17 18 19]
[ 34 35 36 37 38 39 40 41 42 43]
[ 58 59 60 61 62 63 64 65 66 67]
[ 82 83 84 85 86 87 88 89 90 91]
[106 107 108 109 110 111 112 113 114 115]]
y:
[[ 11 12 13 14 15 16 17 18 19 20]
[ 35 36 37 38 39 40 41 42 43 44]
[ 59 60 61 62 63 64 65 66 67 68]
[ 83 84 85 86 87 88 89 90 91 92]
[107 108 109 110 111 112 113 114 115 116]]
在每次迭代中,您构造一个新的feed_dict
,其初始状态为周期单位为零。因此,您假设在每个步骤中从头开始处理序列。这是对的吗?
答案 0 :(得分:1)
Q1。 feed_dict
用于设置周期性单位的初始状态。默认情况下,每次调用run
周期性单位时,处理初始“零”状态的数据。但是,如果您的序列很长,您可能需要将它们分成几个步骤。重要的是,在每个步骤之后,保存循环单位的最终状态并输入下一步的初始状态,否则就好像下一步再次是序列的开始(特别是,如果你的输出只是处理整个序列后网络的最终输出,就像在最后一步之前丢弃所有数据一样)。在每个纪元的开始,代码首先采用默认的“零状态”,然后继续进行循环,其中当前状态作为初始状态给出,模型运行并且输出状态被设置为下一个的新当前状态迭代。
Q2。声称“feed_dict
缓慢”可能会有些误导,被视为一般性的真相(我不是责怪你说的,我见过它也很多次)。 feed_dict
的问题在于它的功能是将非TensorFlow数据(通常是NumPy数据)带入TensorFlow世界。这并不是很糟糕,只是需要一些额外的时间来移动数据,这在涉及大量数据时尤其值得注意。例如,如果要通过feed_dict
输入一批图像,则需要从磁盘加载它们,解码它们,将其转换为大的NumPy数组并将其传递到feed_dict
,然后TensorFlow将将所有数据复制到会话中(GPU内存或其他);所以你会在内存和额外的内存交换中存储两份数据。 tf.data
有帮助,因为它可以在TensorFlow中执行所有操作(这也减少了Python / C行程的数量,有时通常更方便)。在你的情况下,通过feed_dict
提供的是经常性单位的初始状态。除非你有几个非常大的重复层,否则我认为性能影响可能相当小。但 可能是为了在这种情况下避免feed_dict
,你需要有一组TensorFlow变量保持当前状态,设置循环单位以使用它们的输出初始状态(initial_state
参数tf.nn.dynamic_rnn
)并使用其最终状态更新变量值;然后在每个新批次上,您必须再次将变量重新初始化为“零”状态。但是,我会确保在沿着该路线行进之前这将带来显着的好处(例如,使用和不使用feed_dict
测量运行时间,即使结果是错误的)。
编辑:
作为更新的说明,我在这里复制了代码的相关行:
state = session.run(model.initial_state)
fetches = {
"cost": model.cost,
"final_state": model.final_state,
}
if eval_op is not None:
fetches["eval_op"] = eval_op
for step in range(model.input.epoch_size):
feed_dict = {}
for i, (c, h) in enumerate(model.initial_state):
feed_dict[c] = state[i].c
feed_dict[h] = state[i].h
vals = session.run(fetches, feed_dict)
cost = vals["cost"]
state = vals["final_state"]
costs += cost
iters += model.input.num_steps
在一个纪元的开头,state
取model.initial_state
的值,除非给出feed_dict
替换其值,否则它将是默认的“零”初始状态值。 fetches
是一个稍后传递给session.run
的字典,因此它返回另一个字典,其中({等式}}将保存最终状态值。然后,在每个步骤中,创建"final_state"
,用feed_dict
中的数据替换initial_state
张量值,并使用state
调用run
在feed_dict
中检索张量的值,然后fetches
保留vals
调用的输出。第run
行替换了state = vals["final_state"]
的内容,它是我们当前的状态值,具有上次运行的输出状态;所以在下一次迭代state
将保留前一个最后一个状态的值,因此网络将继续“仿佛”一次性给出整个序列。在下一次调用feed_dict
时,run_epoch
将再次初始化为默认值state
,并且该过程将再次从“零”开始。