查看代码:
import tensorflow as tf
import numpy as np
elems = tf.ones([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
print(sess.run(alternates))
输出结果为:
(array([[[1, 1, 1],
[1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
[1, 1, 1]]], dtype=int64), array([[[1, 1, 1],
[1, 1, 1]]], dtype=int64))
我无法理解输出,谁能告诉我?
elems
是一个张量,所以它应该沿着轴0解压缩,然后我们会得到[[1,1,1],[1,1,1]]
,然后map_fn
将[[1,1,1],[1,1,1]]
传递给lambda x:(x,x,x)
},表示x=[[1,1,1],[1,1,1]]
,我认为map_fn
的输出是
[[[1,1,1],[1,1,1]],
[[1,1,1],[1,1,1]],
[[1,1,1],[1,1,1]]]
输出的形状为[3,2,3]
或shape(2,3)
但实际上,输出是张量列表,每个张量的形状是[1,2,3]
。
或换句话说:
import tensorflow as tf
import numpy as np
elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
with tf.Session() as sess:
print(sess.run(alternates))
为什么输出
(array([1, 2, 3], dtype=int64),
array([2, 4, 6], dtype=int64),
array([-1, -2, -3], dtype=int64))
而不是
(array([1, 2, -1], dtype=int64),
array([2, 4, -2], dtype=int64),
array([3, 6, -3], dtype=int64))
这两个问题是一样的。
import tensorflow as tf
import numpy as np
elems = [tf.constant([1,2,3],dtype=tf.int64)]
alternates = tf.map_fn(lambda x: x, elems, dtype=tf.int64)
with tf.Session() as sess:
print(sess.run(alternates))
elems
是张量列表,因此根据api,tf.constant([1,2,3],dtype=tf.int64)
将沿着轴0解压缩,因此map_fn
将作为[x for x in [1,2,3]]
使用,但事实上它会引发错误。
ValueError: The two structures don't have the same nested structure. First struc
ture: <dtype: 'int64'>, second structure: [<tf.Tensor 'map/while/TensorArrayRead
V3:0' shape=() dtype=int64>].
怎么了?
import tensorflow as tf
import numpy as np
elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
print(sess.run(alternates))
输出
(array([1, 2, 3], dtype=int64), array([1, 2, 3], dtype=int64))
似乎elems没有解压缩,为什么?
import tensorflow as tf
import numpy as np
elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: [x], elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
print(sess.run(alternates))
会引发错误
TypeError: The two structures don't have the same sequence type. First structure
has type <class 'tuple'>, while second structure has type <class 'list'>.
谁能告诉我tf.map_fn是如何运作的?
答案 0 :(得分:5)
首先,
elems = tf.ones([1,2,3],dtype=tf.int64)
elems
是一个三维张量,形状为1x2x3,其中有1个,即:
[[[1, 1, 1],
[1, 1, 1]]]
然后,
alternates = tf.map_fn(lambda x: (x, x, x), elems, dtype=(tf.int64, tf.int64, tf.int64))
alternates
是三个张量的元组,其形状与elems
相同,每个张量都是根据给定的函数构建的。由于函数只返回一个元组重复其输入三次,这意味着三个张量将与elems
相同。如果函数是lambda x: (x, 2 * x, -x)
,那么第一个输出张量将与elems
相同,第二个将是elems
的两倍,而第三个则相反。
在所有这些情况下,最好使用常规操作而不是tf.map_fn
;但是,在某些情况下,您可能会接受具有 N 维度的张量的函数,并且您希望将 N + 1的张量应用于其中。
更新:
我认为你正在考虑tf.map_fn
“反过来”,所以说。张量中的元素或行数与函数中的输出数之间没有一对一的对应关系;实际上,你可以传递一个函数,返回一个元组,你可以根据需要返回元素。
举个例子:
elems = tf.constant([1,2,3],dtype=tf.int64)
alternates = tf.map_fn(lambda x: (x, 2*x, -x), elems, dtype=(tf.int64, tf.int64, tf.int64))
tf.map_fn
首先在第一个轴中拆分elems
,即1
,2
和3
,并将函数应用于每个轴,得到:
(1, 2, -1)
(2, 4, -2)
(3, 6, -3)
请注意,正如我所说,这些元组中的每一个都可以包含您想要的元素。现在,产生最终输出结果连接在同一位置;所以你得到:
[1, 2, 3]
[2, 4, 6]
[-1, -2, -3]
同样,如果函数生成具有更多元素的元组,则会获得更多输出张量。
更新2:
关于您的新示例:
import tensorflow as tf
import numpy as np
elems = (tf.constant([1,2,3],dtype=tf.int64),tf.constant([1,2,3],dtype=tf.int64))
alternates = tf.map_fn(lambda x: x, elems, dtype=(tf.int64, tf.int64))
with tf.Session() as sess:
print(sess.run(alternates))
此方法还允许多个元素和fn的输出。如果elems是(可能是嵌套的)列表或张量元组,那么这些张量中的每一个都必须具有匹配的第一个(解包)维度。 fn的签名可以与elems的结构相匹配。也就是说,如果elems是(t1,[t2,t3,[t4,t5]]),那么fn的适当签名是:fn = lambda(t1,[t2,t3,[t4,t5]]):.
此处elems
是两个张量的元组,根据需要在第一维中具有相同的大小。 tf.map_fn
一次获取每个输入张量的一个元素(所以是两个元素的元组)并将给定的函数应用于它,它应该返回与dtypes
中传递的结构相同的结构(元组的元组)两个元素也是);如果你没有给出dtypes
,那么预期的输出与输入相同(同样,两个元素的元组,所以在你的情况下dtypes
是可选的)。无论如何,它是这样的:
f((1, 1)) -> (1, 1)
f((2, 2)) -> (2, 2)
f((3, 3)) -> (3, 3)
将这些结果组合起来,连接结构中的所有相应元素;在这种情况下,第一个位置的所有数字产生第一个输出,第二个位置的所有数字产生第二个输出。结果是,最终,请求的结构(两元素元组)填充了这些连接:
([1, 2, 3], [1, 2, 3])
答案 1 :(得分:1)
您的输入elems
具有形状 (1,2,3)
,如下所示:
[[[1, 1, 1],
[1, 1, 1]]]
不包含值1,2,3
的矩阵,因为您使用tf.ones()
创建了一个填充1
的张量,其中您传递的形状为参数
map_fn
已应用于elems
。
根据{{3}}:
elems:一个张量或(可能是嵌套的)张量序列,每个张量都将沿着它们的第一个维度解包。生成的切片的嵌套序列将应用于fn。
根据我的理解,该函数需要一个张量或张量列表,并且应该将其切片并将函数应用于每个元素。但是,从结果看来,如果传递一个张量,它是直接应用函数的元素,所以x
在调用lambda函数时具有形状(1,2,3)
。
然后,该函数会创建一个元组,其中包含3个(1,2,3)
矩阵副本(输出中为array(...)
)
重构输出行并添加缩进以使其更清晰,输出如下所示:
(
array( # first copy of `x`
[
[
[1, 1, 1],
[1, 1, 1]
]
], dtype=int64
),
array( # second copy of `x`
[
[
[1, 1, 1],
[1, 1, 1]
]
], dtype=int64
),
array( # third copy of `x`
[
[
[1, 1, 1],
[1, 1, 1]
]
], dtype=int64
),
) # end of the tuple
我怀疑你遇到了一个错误。如果将elems定义为列表,则会出现错误,但如果将其定义为
正如@mrry指出的那样,在我的元组中我错过了一个逗号(因此elems是张量本身,而不是包含张量的元组)。tuple
elems = (tf.constant([1,2,3],dtype=tf.int64))
,则代码将按预期工作。对元组和列表的不同处理非常可疑......这就是为什么我认为它是一个错误。