我正在处理一个矩阵sim_mat,它是一个大小的张量(?,?,?)。
我想像这样使用tf.slice:
a=tf.slice(sim_matrix,[0,0,0],tf.stack([tf.shape(sim_matrix[0],tf.shape(sim_matrix)[1],3]))
print(a)
但是这给了我张量大小(?,?,?)而不是(?,?,3)。此外,当我使用a进一步的功能时,我收到一条错误消息:
输入大小(输入深度)必须可通过形状推理访问,但锯值无。
有解决方法吗?
感谢。
答案 0 :(得分:2)
您可以使用精美的索引编写以下等效但更具可读性的行:
a = sim_matrix[ :, :, 0 : 3 ]
然而,这仍然会给你一个形状的张量(?,?,?)。
如果您确实需要修复生成的张量的第三维,请使用tf.gather()
。但是,您必须使用tf.transpose()
之前和之后对维度进行置换,因为tf.gather()
仅适用于第一维。 (也有tf.gather_nd()
,但这里没有帮助,因为你不想在前两个dims上指定索引,你想要切片。)所以这样(测试代码) ):
import tensorflow as tf
sim_matrix = tf.placeholder( shape=(None,None,None), dtype = tf.float32 )
sim_matrixT = tf.transpose( sim_matrix, [ 2, 0, 1 ] )
aT = tf.gather( sim_matrixT, [ 0, 1, 2 ] ) # get first three
a = tf.transpose( aT, [ 1, 2, 0 ] )
print( a.get_shape() )
输出:
(?,?,3)
您还说您在进一步的功能上会出错。如果这些函数在输入张量上需要固定的形状,情况仍然如此。