TensorFlow,使用map_fn迭代每一行并执行计算任务

时间:2019-04-06 06:41:33

标签: tensorflow

我定义了一个函数fn来计算每一行的结果,并且我定义了如下代码:

import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
a = np.array([[1, 2], [2, 1], [1, 1]])
b = np.array([[2, 1], [1, 2], [2, 2]])
c = np.array([[1, 1], [2, 2], [1, 1]])

def fn(v1, v2, v3):  # v1, v2, v3 are rows of np.array
    return v1[0]*v1[1] + v2[0]*v2[1] + v3[0]*v3[1]

tf.map_fn(lambda v1, v2, v3: fn(v1, v2, v3), (a, b, c))

但是,以上代码无法正常工作,如何解决?

我定义了一个函数f3

def f3(a):
    return a[0] + a[1]

tf.map_fn(f3, a)  # there is no error

1 个答案:

答案 0 :(得分:1)

使用tf.stack()可以解决我的问题

tf.map_fn(lambda v1, v2, v3: fn(v1, v2, v3), tf.stack([a, b, c], 1))