使用复杂索引对损失函数进行向量化

时间:2019-09-07 19:32:54

标签: python vectorization pytorch

我已经在pytorch中编写了损失函数,但是它太慢了。 因此,我正在尝试将其向量化。

说明

给出两个网格,它计算在它们上定义的面片(即一组顶点)之间的变形。失真是指每对点之间的距离变化了多少。

enter image description here

  • 补丁的大小可能不同
  • 这假定M和N之间有对应关系。
  • 接受损失的批次。

非矢量化代码

def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple(Tensor,...) = batch['patches']

    loss = 0
    for i in range(y_pred.shape[0]):
        patch_loss = 0
        for patch in patches:
            dist_a = torch.pdist(y_pred[i][patch])
            dist_b = torch.pdist(y_true[i][patch])

            patch_loss += torch.sum((dist_a - dist_b) ** 2)

        loss += patch_loss
    return loss / y_pred.shape[0]

向量化尝试

def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple(Tensor,...) = batch['patches']

    pred_dists = torch.cdist(y_pred, y_pred)
    true_dists = torch.cdist(y_true, y_true)

    diff_dists = (pred_dists - true_dists) ** 2

    loss = 0
    for patch in patches:
        d = diff_dists[:, patch, :][:, :, patch]
        loss += torch.sum(d, dim=(1, 2)) / 2

    return loss.mean()

问题

它仍然太慢。 我在矢量化过程中遇到了困难,因为补丁的尺寸可能不同,因此无法将它们放在张量中。

我觉得我可以直接使用邻接矩阵对代码进行矢量化处理,而无需在补丁中收集顶点索引。

补丁的计算方法如下:

def get_patch(vertex_index, vertex2vertex):
    return vertex2vertex[vertex_index, :].nonzero().squeeze(-1)

工作示例

我准备了一个独立的示例,使用一些基准测试和试验这两个功能。 在基准测试中,有相同的损失函数,没有补丁(即只是所有顶点之间的失真)作为参考时间

import torch
import time

device = 'cuda'
precision = torch.double
batch_size = 8
num_vertices = 400
dim = 2

y_pred = torch.rand(batch_size, num_vertices, dim, device=device, dtype=precision)
y_true = torch.rand(batch_size, num_vertices, dim, device=device, dtype=precision)

adjacency = (torch.rand(num_vertices, num_vertices, device=device, dtype=precision) > 0.95).to(precision)

def get_patch(vertex_index, vertex2vertex):
    return vertex2vertex[vertex_index, :].nonzero().squeeze(-1)

patches = tuple(get_patch(i, adjacency) for i in range(num_vertices))
batch = {
    'vertices': y_true,
    'patches': patches
}

def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple(Tensor,...) = batch['patches']

    loss = 0
    for i in range(y_pred.shape[0]):
        patch_loss = 0
        for patch in patches:
            dist_a = torch.pdist(y_pred[i][patch])
            dist_b = torch.pdist(y_true[i][patch])

            patch_loss += torch.sum((dist_a - dist_b) ** 2)

        loss += patch_loss
    return loss / y_pred.shape[0]

def loss_func_vec(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple(Tensor,...) = batch['patches']

    pred_dists = torch.cdist(y_pred, y_pred)
    true_dists = torch.cdist(y_true, y_true)

    diff_dists = (pred_dists - true_dists) ** 2

    loss = 0
    for patch in patches:
        d = diff_dists[:, patch, :][:, :, patch]
        loss += torch.sum(d, dim=(1, 2)) / 2

    return loss.mean()

def reference_no_patches(y_pred, batch):
    y_true = batch['vertices']
    da = torch.cdist(y_pred, y_pred)
    db = torch.cdist(y_true, y_true)
    return torch.sum((da - db) ** 2) / y_true.shape[0] / 2

print(f'LOSS FUNCTION (ITERATION): TIME (LOSS)')
print()

for i in range(5):
    start_time = time.time()
    res = loss_func(y_pred, batch)
    torch.cuda.synchronize()
    print(f'Standard {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')
print()

for i in range(5):
    start_time = time.time()
    res = loss_func_vec(y_pred, batch)
    torch.cuda.synchronize()
    print(f'Vectorized {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')
print()

for i in range(5):
    start_time = time.time()
    res = reference_no_patches(y_pred, batch)
    torch.cuda.synchronize()
    print(f'Reference without patches {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')
LOSS FUNCTION (ITERATION): TIME (LOSS)

Standard 0: 0.3311s (9744.87875)
Standard 1: 0.2853s (9744.87875)
Standard 2: 0.2929s (9744.87875)
Standard 3: 0.2972s (9744.87875)
Standard 4: 0.2714s (9744.87875)

Vectorized 0:   0.0254s (9744.87875)
Vectorized 1:   0.0251s (9744.87875)
Vectorized 2:   0.0262s (9744.87875)
Vectorized 3:   0.0242s (9744.87875)
Vectorized 4:   0.0241s (9744.87875)

Reference without patches 0:    0.0063s (9695.41083)
Reference without patches 1:    0.0062s (9695.41083)
Reference without patches 2:    0.0062s (9695.41083)
Reference without patches 3:    0.0066s (9695.41083)
Reference without patches 4:    0.0067s (9695.41083)

上下文

涉及的矩阵很小。每个网格中的顶点数约为500,因此邻接矩阵为[500, 500]。 补丁的大小可能有所不同,但是每个补丁的大小都小于<<500。每个顶点有一个补丁(即每个顶点是补丁的中心),因此补丁的数目为O(n!)(因此,我可以不会将补丁堆叠到单个向量中以执行索引操作。

我在示例中使用双精度只是为了避免结果中出现数值错误,我打算使用浮点精度。

最后,我将很乐意收到有关提高速度的任何建议!

更新

由于这个bug,我必须注意内存分配。 我将向量化代码重写为:

def loss_func_vec_mem(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple(Tensor,...) = batch['patches']

    loss = 0
    for patch in patches:
        pred = y_pred[:, patch, :]
        true = y_true[:, patch, :]

        pred_dists = torch.cdist(pred, pred)
        true_dists = torch.cdist(true, true)

        diff_dists = (pred_dists - true_dists) ** 2

        loss += torch.sum(diff_dists, dim=(1, 2)) 
    return loss.mean() / 2

尽管它甚至更慢:

LOSS FUNCTION (ITERATION): TIME (LOSS)

Vec mem 0:  0.0484s (14685.45521)
Vec mem 1:  0.0489s (14685.45521)
Vec mem 2:  0.0471s (14685.45521)
Vec mem 3:  0.0471s (14685.45521)
Vec mem 4:  0.0459s (14685.45521)

Vectorized 0:   0.0271s (29370.91043)
Vectorized 1:   0.0262s (29370.91043)
Vectorized 2:   0.0320s (29370.91043)
Vectorized 3:   0.0266s (29370.91043)
Vectorized 4:   0.0263s (29370.91043)

Standard 0: 0.1680s (14685.45521)
Standard 1: 0.1614s (14685.45521)
Standard 2: 0.1639s (14685.45521)
Standard 3: 0.1606s (14685.45521)
Standard 4: 0.1694s (14685.45521)

Reference without patches 0:    0.0049s (12874.01275)
Reference without patches 1:    0.0105s (12874.01275)
Reference without patches 2:    0.0103s (12874.01275)
Reference without patches 3:    0.0092s (12874.01275)
Reference without patches 4:    0.0100s (12874.01275)

0 个答案:

没有答案