TensorFlow - 3D张量,从2D张量收集每个第N个张量并且步幅1

时间:2017-12-01 18:31:00

标签: python tensorflow slice reshape

假设2D-Tensor中有T [M, 1],例如

T = tf.expand_dims([A1,
     B1,
     C1,
     A2,
     B2,
     C2], 1)

我想像这样重塑它:

T_reshp = [[[A1], [A2]]
           [[B1], [B2]]
           [[C1], [C2]]]

我事先知道MN(每组中的张量数量)。此外,让我{I}尝试使用t_reshp.shape[0] = M/N = P

tf.reshape
T_reshp = tf.reshape(T, [P, N, 1])

然而,我最终得到了:

T_reshp = [[[A1], [B1]]
           [[C1], [A2]]
           [[B2], [C2]]]

我可以使用一些切片或重塑操作吗?

1 个答案:

答案 0 :(得分:2)

您可以先将其重塑为[N,P,1]尺寸,然后transpose第一和第二轴:

tf.transpose(tf.reshape(T, [N, P, 1]), [1,0,2])
#                           ^^^^ switch the two dimensions here and then transpose

实施例

T = tf.expand_dims([1,2,3,4,5,6], 1)
sess = tf.Session()
T1 = tf.transpose(tf.reshape(T, [2,3,1]), [1,0,2])

sess.run(T1)
#array([[[1],
#        [4]],

#       [[2],
#        [5]],

#       [[3],
#        [6]]], dtype=int32)