我正在使用4维张量,并且需要进行一些计算,如以下示例所示。将A
设为形状为(6,64,64,64)
的张量。我想使用函数tf.where
来获取值大于(64,64,64)
的每个0.75
卷的体素。我设法做到这一点的唯一方法是这样的:
X = tf.convert_to_tensor([tf.where(A[i,:,:,:] > 0.75) for i in range(A.shape[0])]
这似乎是非常粗糙的解决方案。有没有更好的方法可以做到这一点?
答案 0 :(得分:0)
您要尝试执行的操作的问题在于,它要求每个(64, 64, 64)
卷具有相同数量的大于0.75的值。如果是这样,您可以执行以下操作:
X = tf.reshape(tf.where(A > 0.75)[:, 1:], (A.shape[0], -1, A.shape.ndims - 1))
但是如果不是这样,您就不能拥有这样的张量,因为第二维需要具有多个尺寸。