pytorch批处理index_select没有循环怎么办?

时间:2019-11-01 07:00:17

标签: pytorch

我有一个点云数据集。我想为形状上的每个点计算KNN点,然后选择它们。我想我可以先计算距离矩阵,然后使用pytorch api torch.topk来获取k个最大的索引和值。但是,根据pytorch文档,我们只能提供一维张量,其中包含要索引到api torch.index_select的索引。这意味着我必须使用for循环来选择这些点。有更有效的方法吗?

以下两个是我用来计算一个形状中每个点的PCA的函数。

def get_distance_matrix(p1,p2):
    p1 = p1.unsqueeze(1)
    p2 = p2.unsqueeze(1)

    p1 = p1.repeat(1, p2.size(2), 1, 1)
    p1 = p1.transpose(1, 2)

    p2 = p2.repeat(1, p1.size(1), 1, 1)

    dist = torch.add(p1, torch.neg(p2))
    dist = torch.norm(dist, 2, dim = 3)

    return dist

NUM_LOCAL_POINTS = 16

# estimate each (index th) point normal by using pca 
def compute_pca(points, distance_matrix, point_index, batch_index):
    distance_vector = distance_matrix[batch_index, point_index] # get the index row from distance matrix
    point_vector = points[batch_index, point_index] # get the point from point clouds

    values, indices = torch.topk(1.0/distance_vector, NUM_LOCAL_POINTS, dim=0) # select knn points

    # create the correct indices shape to slice the points
    local_patch = torch.index_select(points[batch_index], 0, indices) - point_vector

    # compute the pca and the return eigen_value is sotred in non-descreaing order
    local_patch_T = torch.transpose(local_patch, 1, 0) 
    eigen_values, eigen_vectors = torch.eig(torch.matmul(local_patch_T, local_patch), True)

    # find the smallest eigen_value's eigen_vector
    # print("Size of point vector is ", point_vector.size())
    min_index = torch.argmin(eigen_values[:, 0], 0)

    return eigen_vectors[:, min_index]# return the smallest eigen_vectors

0 个答案:

没有答案