我定义了一个函数fn
来计算每一行的结果,并且我定义了如下代码:
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
a = np.array([[1, 2], [2, 1], [1, 1]])
b = np.array([[2, 1], [1, 2], [2, 2]])
c = np.array([[1, 1], [2, 2], [1, 1]])
def fn(v1, v2, v3): # v1, v2, v3 are rows of np.array
return v1[0]*v1[1] + v2[0]*v2[1] + v3[0]*v3[1]
tf.map_fn(lambda v1, v2, v3: fn(v1, v2, v3), (a, b, c))
但是,以上代码无法正常工作,如何解决?
我定义了一个函数f3
def f3(a):
return a[0] + a[1]
tf.map_fn(f3, a) # there is no error
答案 0 :(得分:1)
使用tf.stack()
可以解决我的问题
tf.map_fn(lambda v1, v2, v3: fn(v1, v2, v3), tf.stack([a, b, c], 1))