tensorflow py_func是否支持多个扩充?

时间:2018-01-24 13:55:12

标签: tensorflow

就像标题一样。如果有,请举例说明如何将多个增强功能传递给py_func的第一个增强功能。谢谢

1 个答案:

答案 0 :(得分:0)

是的,您只需传递list个张量,然后将其映射到您函数中的对象。同样,您需要为return语句中的每个输出传递输出类型列表:

import tensorflow as tf

def my_python_function(arg1, arg2):
  return arg1, arg2, arg1+arg2

input_tensor_1 = tf.constant([0,1,2,3,4,5])
input_tensor_2 = tf.constant([2,2,2,2,2,2])

res = tf.py_func(my_python_function, [input_tensor_1, input_tensor_2], [tf.int32, tf.int32, tf.int32])

with tf.Session() as sess:
    print(sess.run(res))

运行上一个示例:

Out[17]: 
[array([0, 1, 2, 3, 4, 5]),
 array([2, 2, 2, 2, 2, 2]),
 array([2, 3, 4, 5, 6, 7])]