无法理解tf.nn.raw_rnn

时间:2018-05-08 22:45:03

标签: python-3.x tensorflow recurrent-neural-network rnn tensorflow-slim

tf.nn.raw_rnn的{​​{3}}中,当第一次运行loop_fn时,我们发出结构作为loop_fn的第三个输出。

稍后,emit_structure用于将tf.zeros_like(emit_structure)复制到由emit = tf.where(finished, tf.zeros_like(emit_structure), emit)完成的小批量条目。

我对谷歌的缺乏理解或糟糕的文档是:发布结构是None所以tf.where(finished, tf.zeros_like(emit_structure), emit)会抛出一个ValueError,因为tf.zeros_like(None)这样做。有人可以填写我在这里缺少的东西吗?

1 个答案:

答案 0 :(得分:1)

是的,这个文档在这个地方比较混乱。如果你看一下tf.nn.raw_rnn的内部,那么关键术语就是"在伪代码" 中,因此文档中的示例并不准确。

确切的源代码如下所示(根据您的tensorflow版本可能会有所不同):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

因此它处理emit_structure is None时的情况并且只取值cell.output_size。这就是为什么没有真正破坏的原因。