tf.map_fn如何工作?

时间:2017-08-27 13:44:24

标签: tensorflow

看看演示:

elems = np.array([1, 2, 3, 4, 5, 6])
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

elems = np.array([1, 2, 3])
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]

我无法理解第二和第三。

第二个: 我认为结果是[2,-1],因为第一次x = np.array([1,2,3])并返回1 * 2,第二次x = np.array([ - 1,1] ,-1])并返回1 *( - 1)

第三个: 我认为结果的形状是(3,2),因为第一次x = 1并返回(1,-1),第二次x = 2并返回(2,-2),第三次x = 3并返回(3,-3)。

那么map_fn如何运作?

2 个答案:

答案 0 :(得分:1)

Tensorflow map_fn,来自文档,

  

映射在尺寸0上从elems解包的张量列表。

在这种情况下,输入张量的唯一轴[1,2,3],或[-1,1,-1]。因此操作是1 * -1,2 * 1和3 * -1,并且重新打包结果,给出张量形状。

答案 1 :(得分:1)

  

对于第二次:我认为结果是[2,-1],因为第一次   x = np.array([1,2,3])并返回1 * 2,第二次x = np.array([ - 1,   1,-1])并返回1 *( - 1)

In [26]: a = np.array([[1, 2, 3], [2, 4, 1], [5, 1, 7]])
In [27]: b = np.array([[1, -1, -1], [1, 1, 1], [-1, 1, -1]])
In [28]: elems = (a, b)    
In [29]: alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
In [30]: alternate.eval()
Out[30]: 
array([[ 1, -2, -3],
       [ 2,  4,  1],
       [-5,  1, -7]])

您将看到它是应用于函数的elems中每个元素的0维度的张量。

  

对于第三种:我认为结果的形状是(3,2),因为第一种   时间x = 1并返回(1,-1),第二次x = 2并返回(2,-2),   第三次x = 3并返回(3,-3)。

In [36]: elems = np.array([[1, 2, 3], [4, 5, 1], [1, 6, 1]])
In [37]: alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
In [38]: alternates
Out[38]: 
(<tf.Tensor 'map_6/TensorArrayStack/TensorArrayGatherV3:0' shape=(3, 3) dtype=int64>,
 <tf.Tensor 'map_6/TensorArrayStack_1/TensorArrayGatherV3:0' shape=(3, 3) dtype=int64>)
In [39]: alternates[0].eval()
Out[39]: 
array([[1, 2, 3],
       [4, 5, 1],
       [1, 6, 1]])
In [40]: alternates[1].eval()
Out[40]: 
array([[-1, -2, -3],
       [-4, -5, -1],
       [-1, -6, -1]])

获得您期望的结果:

In [8]: elems = np.array([[1], [2], [3]])                                                          
In [9]: alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
In [10]: sess = tf.InteractiveSession()                                                            
In [11]: alternates[0].eval()
Out[11]: 
array([[1],
       [2],
       [3]])

In [12]: alternates[1].eval()                                                                      
Out[12]: 
array([[-1],
       [-2],
       [-3]])

这可以帮助您更好地理解map_fn。