From the docs:
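Transposes a. Permutes the dimensions according to perm. The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1 ... 0), where n is the rank of the input tensor. Hence, by default, this operation performs a regular matrix transpose on 2-D input tensors.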
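To make "dimension i of the result corresponds to input dimension perm[i]" concrete, here is a minimal loop-based sketch. It uses NumPy instead of TensorFlow, assuming that the axes argument of np.transpose follows the same convention as perm in tf.transpose (NumPy documents it the same way). naive_transpose is a hypothetical helper written only for illustration.

```python
import itertools

import numpy as np


def naive_transpose(a, perm):
    """Loop-based transpose: output axis i is input axis perm[i]."""
    a = np.asarray(a)
    # Output dimension i has the size of input dimension perm[i].
    out_shape = tuple(a.shape[p] for p in perm)
    out = np.empty(out_shape, dtype=a.dtype)
    for out_idx in itertools.product(*(range(n) for n in out_shape)):
        # Input axis perm[i] is indexed by the i-th output coordinate.
        in_idx = [0] * a.ndim
        for i, p in enumerate(perm):
            in_idx[p] = out_idx[i]
        out[out_idx] = a[tuple(in_idx)]
    return out


a = np.arange(1, 25).reshape(2, 3, 4)
for perm in itertools.permutations(range(3)):
    assert np.array_equal(naive_transpose(a, perm), np.transpose(a, perm))
print("naive_transpose matches np.transpose for every perm")
```

The loop is deliberately slow and explicit; it exists only to show which input element lands at each output position.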
But I'm still a bit unclear on how I should slice the input tensor. E.g. from the docs:
tf.transpose(x, perm=[0, 2, 1]) ==> [[[1 4]
[2 5]
[3 6]]
[[7 10]
[8 11]
[9 12]]]
Why does perm=[0,2,1] produce a 2x3x2 tensor?
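The input x of that docs example is not shown here. Because perm=[0,2,1] only swaps the last two axes, it is its own inverse, so applying it to the printed output recovers x. A minimal NumPy sketch, again assuming np.transpose uses the same convention as tf.transpose:

```python
import numpy as np

# The output printed in the docs example above (two 3x2 blocks).
docs_output = np.array([[[1, 4], [2, 5], [3, 6]],
                        [[7, 10], [8, 11], [9, 12]]])

# perm=[0, 2, 1] swaps axes 1 and 2, so applying it twice is the identity.
x = np.transpose(docs_output, (0, 2, 1))
print(x.shape)  # (2, 2, 3)
print(x)

# Transposing x with the same perm reproduces the docs output.
assert np.array_equal(np.transpose(x, (0, 2, 1)), docs_output)
```

So the input is two 2x3 matrices; axis 0 (which matrix) stays put and each matrix is transposed, which gives two 3x2 matrices.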
After some trial and error:
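import numpy as np
import tensorflow as tf  # the Session-based code below uses the TensorFlow 1.x API
twothreefour = np.array([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                         [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])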
twothreefour
[OUT]:
array([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
If I transpose it:
fourthreetwo = tf.transpose(twothreefour)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(fourthreetwo.eval())
It goes from 2x3x4 to 4x3x2, which sounds logical.
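That matches the documented default: with no perm, the axes are reversed, so for a rank-3 input perm defaults to (2, 1, 0), and element [i, j, k] of the result is element [k, j, i] of the input. A NumPy sketch of that check, assuming the same convention as tf.transpose:

```python
import numpy as np

twothreefour = np.arange(1, 25).reshape(2, 3, 4)  # same values as the array above

default_t = np.transpose(twothreefour)               # no axes given: reverse them
explicit_t = np.transpose(twothreefour, (2, 1, 0))   # the documented default perm
assert np.array_equal(default_t, explicit_t)
print(default_t.shape)  # (4, 3, 2)

# Element-wise: output [i, j, k] comes from input [k, j, i].
for i, j, k in np.ndindex(default_t.shape):
    assert default_t[i, j, k] == twothreefour[k, j, i]
```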
[OUT]:
[[[ 1 13]
[ 5 17]
[ 9 21]]
[[ 2 14]
[ 6 18]
[10 22]]
[[ 3 15]
[ 7 19]
[11 23]]
[[ 4 16]
[ 8 20]
[12 24]]]
But when I use the perm argument, I'm not sure what I'm actually getting:
twofourthree = tf.transpose(twothreefour, perm=[0, 2, 1])
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(twofourthree.eval())
[OUT]:
[[[ 1 5 9]
[ 2 6 10]
[ 3 7 11]
[ 4 8 12]]
[[13 17 21]
[14 18 22]
[15 19 23]
[16 20 24]]]
Why does perm=[0,2,1] turn a 2x3x4 matrix into a 2x4x3 one?
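One way to read perm=[0,2,1]: axis 0 (which 3x4 block) stays in place, while axes 1 and 2 trade places, so each 3x4 block is transposed on its own into a 4x3 block. The result shape is (shape[0], shape[2], shape[1]) = (2, 4, 3). A NumPy sketch of that reading:

```python
import numpy as np

twothreefour = np.arange(1, 25).reshape(2, 3, 4)
twofourthree = np.transpose(twothreefour, (0, 2, 1))
print(twofourthree.shape)  # (2, 4, 3)

# Each block along axis 0 is the ordinary matrix transpose of the matching input block.
for b in range(twothreefour.shape[0]):
    assert np.array_equal(twofourthree[b], twothreefour[b].T)
```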
Trying again with perm=[1,0,2]:
threetwofour = tf.transpose(twothreefour, perm=[1, 0, 2])
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print(threetwofour.eval())
[OUT]:
[[[ 1 2 3 4]
[13 14 15 16]]
[[ 5 6 7 8]
[17 18 19 20]]
[[ 9 10 11 12]
[21 22 23 24]]]
Why does perm=[1,0,2] turn 2x3x4 into 3x2x4?
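With perm=[1,0,2], axes 0 and 1 swap and axis 2 (the rows of length 4) is untouched. So block j of the result collects row j from every original block: result[j, i, :] equals input[i, j, :], and the shape becomes (shape[1], shape[0], shape[2]) = (3, 2, 4). A NumPy sketch:

```python
import numpy as np

twothreefour = np.arange(1, 25).reshape(2, 3, 4)
threetwofour = np.transpose(twothreefour, (1, 0, 2))
print(threetwofour.shape)  # (3, 2, 4)

# Row j of input block i becomes row i of output block j.
for i in range(2):
    for j in range(3):
        assert np.array_equal(threetwofour[j, i], twothreefour[i, j])
```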
Does this mean the perm argument takes my np.shape and transposes the tensor's elements according to the array's shape?
I.e.:
_size = (2, 4, 3, 5)
randarray = np.random.randint(5, size=_size)
shape_idx = {i: _s for i, _s in enumerate(_size)}
randarray_t_func = tf.transpose(randarray, perm=[3, 0, 2, 1])
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    transposed_array = randarray_t_func.eval()
    print(transposed_array.shape)
    print(tuple(shape_idx[_s] for _s in [3, 0, 2, 1]))
[OUT]:
(5, 2, 3, 4)
(5, 2, 3, 4)
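The shape part of that guess holds (output dimension k has size shape[perm[k]]), and the elements move accordingly: since output axis k is input axis perm[k], for perm=[3,0,2,1] the element out[a, b, c, d] is input[b, d, c, a]. A NumPy sketch checking both rules on an array shaped like the one above; the fixed seed is an assumption added only for reproducibility.

```python
import numpy as np

rng = np.random.default_rng(0)
randarray = rng.integers(5, size=(2, 4, 3, 5))
perm = (3, 0, 2, 1)
out = np.transpose(randarray, perm)

# Shape rule: output dim k has the size of input dim perm[k].
assert out.shape == tuple(randarray.shape[p] for p in perm)
assert out.shape == (5, 2, 3, 4)

# Element rule: input axis perm[k] is indexed by output coordinate k.
for a, b, c, d in np.ndindex(out.shape):
    assert out[a, b, c, d] == randarray[b, d, c, a]
```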
Answer 0 (score: 28)
I think perm is permuting the dimensions. For example, perm=[0,2,1] is short for dim_0 -> dim_0, dim_1 -> dim_2, dim_2 -> dim_1. So for a 2D tensor, perm=[1,0] is just a matrix transpose. Does this answer your question?
Answer 1 (score: 2)
For a tensor A of shape [2,3,4], using perm=(1,0,2) gives B of shape [3,2,4].
Explanation:
Index=(0,1,2)
A =[2,3,4]
Perm =(1,0,2)
B =(3,2,4) --> perm 1 takes index 1 of A (3), perm 0 takes index 0 (2), perm 2 takes index 2 (4) --> so B is (3,2,4)
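The lookup in this explanation can be written as a one-line shape computation. permuted_shape is a hypothetical helper for illustration, not part of TensorFlow: dimension i of B is dimension perm[i] of A.

```python
def permuted_shape(shape, perm):
    """Return the shape after transposing with perm: dim i is shape[perm[i]]."""
    if sorted(perm) != list(range(len(shape))):
        raise ValueError(f"perm {perm} is not a permutation of 0..{len(shape) - 1}")
    return tuple(shape[p] for p in perm)


print(permuted_shape((2, 3, 4), (1, 0, 2)))        # (3, 2, 4)
print(permuted_shape((2, 4, 3, 5), (3, 0, 2, 1)))  # (5, 2, 3, 4)
```

The validity check mirrors the requirement that perm be a permutation of 0..n-1; the element layout still follows the per-axis rule described in the docs.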