我一直试图在Tensorflow Serving中使用基于seq2seq的模型。
在没有波束搜索的情况下似乎可以正常工作。
但是,当启用波束搜索时,推断中还会涉及tf.py_func
。
由于tf.py_func
无法序列化为GraphDef,因此我无法在Tensorflow Serving中使用它。
我需要将使用过的here的tf.py_func
转换为纯TF操作。
该py_func的代码段
def gather_tree_py(values, parents):
beam_length = values.shape[0]
num_beams = values.shape[1]
res = np.zeros_like(values)
res[-1, :] = values[-1, :]
for beam_id in range(num_beams):
parent = parents[-1][beam_id]
for level in reversed(range(beam_length - 1)):
res[level, beam_id] = values[level][parent]
parent = parents[level][parent]
return np.array(res).astype(values.dtype)
def gather_tree(values, parents):
res = tf.py_func(func=gather_tree_py, inp=[values, parents], Tout=values.dtype)
因为我是一个初学者,所以在转换方面我需要一些帮助。
我该如何进行转换?
也欢迎参阅此issue。