tf.contrib.seq2seq.gather_tree如何运作?

时间:2018-01-03 12:49:14

标签: tensorflow beam-search

gather_treecontrib.seq2seq的工作原理究竟如何?我可以看到它需要预测的ids和梁父ID并以某种方式返回最后的光束,但实际上在引擎盖下面是什么?似乎没有任何我可以检查的Python代码库来弄清楚它。 API不是很清楚;

tf.contrib.seq2seq.gather_tree是否有任何代码来源?我正在使用TensorFlow 1.3并查看gen_beam_search_ops.py内部似乎没有帮助。

1 个答案:

答案 0 :(得分:0)

代码详述如下:

def gather_tree_py(values, parents):
  """Gathers path through a tree backwards from the leave nodes. Used
  to reconstruct beams given their parents."""

  beam_length = values.shape[0]
  num_beams = values.shape[1]
  res = np.zeros_like(values)
  res[-1, :] = values[-1, :]
  for beam_id in range(num_beams):
    parent = parents[-1][beam_id]
    for level in reversed(range(beam_length - 1)):
      res[level, beam_id] = values[level][parent]
      parent = parents[level][parent]
  return np.array(res).astype(values.dtype)


def gather_tree(values, parents):
  """Tensor version of gather_tree_py"""

  res = tf.py_func(
      func=gather_tree_py, inp=[values, parents], Tout=values.dtype)
  res.set_shape(values.get_shape().as_list())
  return res

github: seq2seq beam_search