将tf.py_func转换为原生Tensorflow操作

时间:2019-03-14 09:27:18

标签: python tensorflow tensorflow-serving seq2seq

我一直试图在Tensorflow Serving中使用基于seq2seq的模型。

在没有波束搜索的情况下似乎可以正常工作。

但是,当启用波束搜索时,推断中还会涉及tf.py_func

由于tf.py_func无法序列化为GraphDef,因此我无法在Tensorflow Serving中使用它。

我需要将使用过的heretf.py_func转换为纯TF操作。

该py_func的代码段

def gather_tree_py(values, parents):
  beam_length = values.shape[0]
  num_beams = values.shape[1]
  res = np.zeros_like(values)
  res[-1, :] = values[-1, :]
  for beam_id in range(num_beams):
    parent = parents[-1][beam_id]
    for level in reversed(range(beam_length - 1)):
      res[level, beam_id] = values[level][parent]
      parent = parents[level][parent]
  return np.array(res).astype(values.dtype)

def gather_tree(values, parents):
  res = tf.py_func(func=gather_tree_py, inp=[values, parents], Tout=values.dtype)

因为我是一个初学者,所以在转换方面我需要一些帮助。

我该如何进行转换?

也欢迎参阅此issue

0 个答案:

没有答案