在tf.nn.raw_rnn
的{{3}}中,当第一次运行loop_fn
时,我们发出结构作为loop_fn
的第三个输出。
稍后,emit_structure用于将tf.zeros_like(emit_structure)
复制到由emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
完成的小批量条目。
我对谷歌的缺乏理解或糟糕的文档是:发布结构是None
所以tf.where(finished, tf.zeros_like(emit_structure), emit)
会抛出一个ValueError,因为tf.zeros_like(None)
这样做。有人可以填写我在这里缺少的东西吗?
答案 0 :(得分:1)
是的,这个文档在这个地方比较混乱。如果你看一下tf.nn.raw_rnn
的内部,那么关键术语就是"在伪代码" 中,因此文档中的示例并不准确。
确切的源代码如下所示(根据您的tensorflow版本可能会有所不同):
if emit_structure is not None:
flat_emit_structure = nest.flatten(emit_structure)
flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
array_ops.shape(emit) for emit in flat_emit_structure]
flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
emit_structure = cell.output_size
flat_emit_size = nest.flatten(emit_structure)
flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
因此它处理emit_structure is None
时的情况并且只取值cell.output_size
。这就是为什么没有真正破坏的原因。