如何解释tf.map_fn的结果?

时间:2017-09-07 12:47:03

标签: tensorflow

查看代码:

import tensorflow as tf
import numpy as np

elems = tf.ones([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

输出结果为:

(array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
        [1, 1, 1]]], dtype=int64))

我无法理解输出,谁能告诉我?

更新

elems是一个张量,所以它应该沿着轴0解压缩,然后我们会得到[[1,1,1],[1,1,1]],然后map_fn[[1,1,1],[1,1,1]]传递给lambda x:(x,x,x) },表示x=[[1,1,1],[1,1,1]],我认为map_fn的输出是

[[[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]],
 [[1,1,1],[1,1,1]]]

输出的形状为[3,2,3]shape(2,3)

的列表

但实际上,输出是张量列表,每个张量的形状是[1,2,3]

或换句话说:

import tensorflow as tf
import numpy as np

elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

为什么输出

(array([1, 2, 3], dtype=int64), 
 array([2, 4, 6], dtype=int64), 
 array([-1, -2, -3], dtype=int64))

而不是

(array([1, 2, -1], dtype=int64), 
 array([2, 4, -2], dtype=int64), 
 array([3, 6, -3], dtype=int64))

这两个问题是一样的。

UPDATE2

import tensorflow as tf
import numpy as np

elems = [tf.constant([1,2,3],dtype=tf.int64)]
alternates = tf.map_fn(lambda x: x, elems, dtype=tf.int64)
with tf.Session() as sess:
    print(sess.run(alternates))

elems是张量列表,因此根据api,tf.constant([1,2,3],dtype=tf.int64)将沿着轴0解压缩,因此map_fn将作为[x for x in [1,2,3]]使用,但事实上它会引发错误。

ValueError: The two structures don't have the same nested structure. First struc
ture: <dtype: 'int64'>, second structure: [<tf.Tensor 'map/while/TensorArrayRead
V3:0' shape=() dtype=int64>].

怎么了?

UPDATE3

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

输出

(array([1, 2, 3], dtype=int64), array([1, 2, 3], dtype=int64))

似乎elems没有解压缩,为什么?

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: [x], elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

会引发错误

TypeError: The two structures don't have the same sequence type. First structure
 has type <class 'tuple'>, while second structure has type <class 'list'>.

谁能告诉我tf.map_fn是如何运作的?

2 个答案:

答案 0 :(得分:5)

首先,

elems = tf.ones([1,2,3],dtype=tf.int64)

elems是一个三维张量,形状为1x2x3,其中有1个,即:

[[[1, 1, 1],
  [1, 1, 1]]]

然后,

alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))

alternates是三个张量的元组,其形状与elems相同,每个张量都是根据给定的函数构建的。由于函数只返回一个元组重复其输入三次,这意味着三个张量将与elems相同。如果函数是lambda x: (x, 2 * x, -x),那么第一个输出张量将与elems相同,第二个将是elems的两倍,而第三个则相反。

在所有这些情况下,最好使用常规操作而不是tf.map_fn;但是,在某些情况下,您可能会接受具有 N 维度的张量的函数,并且您希望将 N + 1的张量应用于其中。

更新:

我认为你正在考虑tf.map_fn“反过来”,所以说。张量中的元素或行数与函数中的输出数之间没有一对一的对应关系;实际上,你可以传递一个函数,返回一个元组,你可以根据需要返回元素。

举个例子:

elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))

tf.map_fn首先在第一个轴中拆分elems,即123,并将函数应用于每个轴,得到:

(1, 2, -1)
(2, 4, -2)
(3, 6, -3)

请注意,正如我所说,这些元组中的每一个都可以包含您想要的元素。现在,产生最终输出结果连接在同一位置;所以你得到:

[1, 2, 3]
[2, 4, 6]
[-1, -2, -3]

同样,如果函数生成具有更多元素的元组,则会获得更多输出张量。

更新2:

关于您的新示例:

import tensorflow as tf
import numpy as np

elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
    print(sess.run(alternates))

documentation说:

  

此方法还允许多个元素和fn的输出。如果elems是(可能是嵌套的)列表或张量元组,那么这些张量中的每一个都必须具有匹配的第一个(解包)维度。 fn的签名可以与elems的结构相匹配。也就是说,如果elems是(t1,[t2,t3,[t4,t5]]),那么fn的适当签名是:fn = lambda(t1,[t2,t3,[t4,t5]]):.

此处elems是两个张量的元组,根据需要在第一维中具有相同的大小。 tf.map_fn一次获取每个输入张量的一个元素(所以是两个元素的元组)并将给定的函数应用于它,它应该返回与dtypes中传递的结构相同的结构(元组的元组)两个元素也是);如果你没有给出dtypes,那么预期的输出与输入相同(同样,两个元素的元组,所以在你的情况下dtypes是可选的)。无论如何,它是这样的:

f((1, 1)) -> (1, 1)
f((2, 2)) -> (2, 2)
f((3, 3)) -> (3, 3)

将这些结果组合起来,连接结构中的所有相应元素;在这种情况下,第一个位置的所有数字产生第一个输出,第二个位置的所有数字产生第二个输出。结果是,最终,请求的结构(两元素元组)填充了这些连接:

([1, 2, 3], [1, 2, 3])

答案 1 :(得分:1)

您的输入elems具有形状 (1,2,3),如下所示:

[[[1, 1, 1],
 [1, 1, 1]]]

包含值1,2,3的矩阵,因为您使用tf.ones()创建了一个填充1的张量,其中您传递的形状为参数

回复更新:

map_fn已应用于elems。 根据{{​​3}}:

  

elems:一个张量或(可能是嵌套的)张量序列,每个张量都将沿着它们的第一个维度解包。生成的切片的嵌套序列将应用于fn。

根据我的理解,该函数需要一个张量或张量列表,并且应该将其切片并将函数应用于每个元素。但是,从结果看来,如果传递一个张量,它是直接应用函数的元素,所以x在调用lambda函数时具有形状(1,2,3)。 然后,该函数会创建一个元组,其中包含3个(1,2,3)矩阵副本(输出中为array(...)

重构输出行并添加缩进以使其更清晰,输出如下所示:

( 
   array( # first copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
   array( # second copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
   array( # third copy of `x`
       [
           [
               [1, 1, 1],
               [1, 1, 1]
           ]
       ], dtype=int64
   ), 
) # end of the tuple

更新2:

我怀疑你遇到了一个错误。如果将elems定义为列表,则会出现错误,但如果将其定义为tuple elems = (tf.constant([1,2,3],dtype=tf.int64)),则代码将按预期工作。对元组和列表的不同处理非常可疑......这就是为什么我认为它是一个错误。 正如@mrry指出的那样,在我的元组中我错过了一个逗号(因此elems是张量本身,而不是包含张量的元组)。