如何在未知大小的张量上使用tf.slice

时间:2018-04-16 13:03:30

标签: python tensorflow

我正在处理一个矩阵sim_mat,它是一个大小的张量(?,?,?)。

我想像这样使用tf.slice:

a=tf.slice(sim_matrix,[0,0,0],tf.stack([tf.shape(sim_matrix[0],tf.shape(sim_matrix)[1],3]))
print(a)

但是这给了我张量大小(?,?,?)而不是(?,?,3)。此外,当我使用a进一步的功能时,我收到一条错误消息:

输入大小(输入深度)必须可通过形状推理访问,但锯值无。

有解决方法吗?

感谢。

1 个答案:

答案 0 :(得分:2)

您可以使用精美的索引编写以下等效但更具可读性的行:

a = sim_matrix[ :, :, 0 : 3 ]

然而,这仍然会给你一个形状的张量(?,?,?)。

如果您确实需要修复生成的张量的第三维,请使用tf.gather()。但是,您必须使用tf.transpose()之前和之后对维度进行置换,因为tf.gather()仅适用于第一维。 (也有tf.gather_nd(),但这里没有帮助,因为你不想在前两个dims上指定索引,你想要切片。)所以这样(测试代码) ):

import tensorflow as tf

sim_matrix = tf.placeholder( shape=(None,None,None), dtype = tf.float32 )
sim_matrixT = tf.transpose( sim_matrix, [ 2, 0, 1 ] )
aT = tf.gather( sim_matrixT, [ 0, 1, 2 ] ) # get first three
a = tf.transpose( aT, [ 1, 2, 0 ] )

print( a.get_shape() )

输出:

  

(?,?,3)

您还说您在进一步的功能上会出错。如果这些函数在输入张量上需要固定的形状,情况仍然如此。