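I have written a loss function in PyTorch, but it is too slow, so I am trying to vectorize it.
Given two meshes, it computes the distortion between patches (i.e. sets of vertices) defined on them. Distortion here means how much the distance between each pair of points changes.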
from typing import Tuple

import torch
from torch import Tensor


def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple[Tensor, ...] = batch['patches']
    loss = 0
    # Loop over the meshes in the batch
    for i in range(y_pred.shape[0]):
        patch_loss = 0
        for patch in patches:
            # Pairwise distances among the vertices of this patch
            dist_a = torch.pdist(y_pred[i][patch])
            dist_b = torch.pdist(y_true[i][patch])
            patch_loss += torch.sum((dist_a - dist_b) ** 2)
        loss += patch_loss
    return loss / y_pred.shape[0]
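# Partially vectorized version (named loss_func_vec in the benchmark below)
def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple[Tensor, ...] = batch['patches']
    # All pairwise distances, computed once per mesh: shape (batch, n, n)
    pred_dists = torch.cdist(y_pred, y_pred)
    true_dists = torch.cdist(y_true, y_true)
    diff_dists = (pred_dists - true_dists) ** 2
    loss = 0
    for patch in patches:
        # Select the rows and columns belonging to this patch
        d = diff_dists[:, patch, :][:, :, patch]
        # Each unordered pair is counted twice in the symmetric block
        loss += torch.sum(d, dim=(1, 2)) / 2
    return loss.mean()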
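It is still too slow. I am struggling with the vectorization because the patches can have different sizes, so they cannot be stacked into a single tensor.
I feel I could vectorize the code by using the adjacency matrix directly, without collecting the vertex indices of each patch.
The patches are computed as follows: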
def get_patch(vertex_index, vertex2vertex):
    # Indices of the nonzero entries in this vertex's row of the adjacency matrix
    return vertex2vertex[vertex_index, :].nonzero().squeeze(-1)
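The idea of using the adjacency matrix directly could be sketched as follows. Since patch i is the set of nonzero columns in row i of the adjacency matrix A, the number of patches that contain both vertex j and vertex k is (AᵀA)[j, k], so the sum over patches collapses into one weighted sum over the full matrix of squared distance differences. This is only a sketch under the assumption that the adjacency matrix contains nothing but 0s and 1s; `patch_loss_from_adjacency` and `patch_loss_loop` are illustrative names, not functions from the code above.

```python
import torch


def patch_loss_from_adjacency(y_pred, y_true, adjacency):
    """Patch distortion loss without a Python loop over patches.

    adjacency: (N, N) 0/1 matrix; row i marks the vertices of patch i
    (assumption: same convention as get_patch above).
    """
    a = adjacency.to(y_pred.dtype)
    # weights[j, k] = number of patches containing both vertex j and vertex k
    weights = a.T @ a
    diff = (torch.cdist(y_pred, y_pred) - torch.cdist(y_true, y_true)) ** 2
    # Every unordered pair appears twice in the symmetric matrix, hence / 2
    per_sample = (diff * weights).sum(dim=(1, 2)) / 2
    return per_sample.mean()


def patch_loss_loop(y_pred, y_true, adjacency):
    """Loop-based reference using the same arithmetic, for comparison only."""
    diff = (torch.cdist(y_pred, y_pred) - torch.cdist(y_true, y_true)) ** 2
    total = torch.zeros(y_pred.shape[0], dtype=y_pred.dtype, device=y_pred.device)
    for i in range(adjacency.shape[0]):
        patch = adjacency[i].nonzero().squeeze(-1)
        total += diff[:, patch, :][:, :, patch].sum(dim=(1, 2)) / 2
    return total.mean()


# Small CPU check that both versions agree
torch.manual_seed(0)
y_pred = torch.rand(2, 50, 2, dtype=torch.double)
y_true = torch.rand(2, 50, 2, dtype=torch.double)
adjacency = (torch.rand(50, 50) > 0.9).to(torch.double)
fast = patch_loss_from_adjacency(y_pred, y_true, adjacency)
slow = patch_loss_loop(y_pred, y_true, adjacency)
print(torch.allclose(fast, slow))
```

If the patches are fixed for a given mesh, the `weights` matrix could presumably be computed once and reused across training steps.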
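I have prepared a standalone example with a small benchmark to try out the two functions. The benchmark also includes the same loss without patches (i.e. the distortion between all pairs of vertices) as a reference time.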
import time
from typing import Tuple

import torch
from torch import Tensor

device = 'cuda'
precision = torch.double

batch_size = 8
num_vertices = 400
dim = 2

y_pred = torch.rand(batch_size, num_vertices, dim, device=device, dtype=precision)
y_true = torch.rand(batch_size, num_vertices, dim, device=device, dtype=precision)
# Random 0/1 adjacency matrix with roughly 5% nonzero entries
adjacency = (torch.rand(num_vertices, num_vertices, device=device, dtype=precision) > 0.95).to(precision)


def get_patch(vertex_index, vertex2vertex):
    return vertex2vertex[vertex_index, :].nonzero().squeeze(-1)


# One patch per vertex
patches = tuple(get_patch(i, adjacency) for i in range(num_vertices))
batch = {
    'vertices': y_true,
    'patches': patches
}
def loss_func(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple[Tensor, ...] = batch['patches']
    loss = 0
    for i in range(y_pred.shape[0]):
        patch_loss = 0
        for patch in patches:
            dist_a = torch.pdist(y_pred[i][patch])
            dist_b = torch.pdist(y_true[i][patch])
            patch_loss += torch.sum((dist_a - dist_b) ** 2)
        loss += patch_loss
    return loss / y_pred.shape[0]


def loss_func_vec(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple[Tensor, ...] = batch['patches']
    pred_dists = torch.cdist(y_pred, y_pred)
    true_dists = torch.cdist(y_true, y_true)
    diff_dists = (pred_dists - true_dists) ** 2
    loss = 0
    for patch in patches:
        d = diff_dists[:, patch, :][:, :, patch]
        loss += torch.sum(d, dim=(1, 2)) / 2
    return loss.mean()


def reference_no_patches(y_pred, batch):
    # Distortion over all vertex pairs, ignoring patches (timing reference only)
    y_true = batch['vertices']
    da = torch.cdist(y_pred, y_pred)
    db = torch.cdist(y_true, y_true)
    return torch.sum((da - db) ** 2) / y_true.shape[0] / 2
print('LOSS FUNCTION (ITERATION): TIME (LOSS)')

print()
for i in range(5):
    start_time = time.time()
    res = loss_func(y_pred, batch)
    torch.cuda.synchronize()  # wait for the GPU before reading the clock
    print(f'Standard {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')

print()
for i in range(5):
    start_time = time.time()
    res = loss_func_vec(y_pred, batch)
    torch.cuda.synchronize()
    print(f'Vectorized {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')

print()
for i in range(5):
    start_time = time.time()
    res = reference_no_patches(y_pred, batch)
    torch.cuda.synchronize()
    print(f'Reference without patches {i}:\t{time.time() - start_time:.4f}s ({res:.5f})')
LOSS FUNCTION (ITERATION): TIME (LOSS)
Standard 0: 0.3311s (9744.87875)
Standard 1: 0.2853s (9744.87875)
Standard 2: 0.2929s (9744.87875)
Standard 3: 0.2972s (9744.87875)
Standard 4: 0.2714s (9744.87875)
Vectorized 0: 0.0254s (9744.87875)
Vectorized 1: 0.0251s (9744.87875)
Vectorized 2: 0.0262s (9744.87875)
Vectorized 3: 0.0242s (9744.87875)
Vectorized 4: 0.0241s (9744.87875)
Reference without patches 0: 0.0063s (9695.41083)
Reference without patches 1: 0.0062s (9695.41083)
Reference without patches 2: 0.0062s (9695.41083)
Reference without patches 3: 0.0066s (9695.41083)
Reference without patches 4: 0.0067s (9695.41083)
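The first call of each function in a run like this can include one-off start-up cost, so a timing helper with a few warm-up calls may give steadier numbers. The following is a minimal sketch, not the harness used for the results above; `benchmark`, `warmup` and `iters` are hypothetical names, and it falls back to CPU timing when no CUDA device is available.

```python
import time

import torch


def benchmark(fn, *args, warmup=2, iters=5):
    """Return per-call wall-clock times in seconds, excluding warm-up calls."""
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(iters):
        if use_cuda:
            torch.cuda.synchronize()  # make sure earlier kernels have finished
        start = time.perf_counter()
        fn(*args)
        if use_cuda:
            torch.cuda.synchronize()  # wait for the kernels launched by fn
        times.append(time.perf_counter() - start)
    return times


# Example: time a pairwise-distance computation on a small random batch
x = torch.rand(8, 400, 2)
times = benchmark(lambda t: torch.cdist(t, t).sum(), x)
print([f'{t:.4f}s' for t in times])
```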
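The matrices involved are small. Each mesh has about 500 vertices, so the adjacency matrix is [500, 500].
Patch sizes vary, but every patch is much smaller than the mesh (size << 500). There is one patch per vertex (i.e. each vertex is the center of a patch), so the number of patches is O(n). Because the patches have different sizes, I cannot stack them into a single tensor to perform the indexing in one operation.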
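One common workaround for variable-size index sets is to pad them to a common length and carry a validity mask. The sketch below illustrates that idea only; `build_padded_patches` and `padded_patch_loss` are hypothetical helpers, padded slots are assumed to point at vertex 0 and are masked out, and at least one patch is assumed to be non-empty. Whether it is actually faster depends on how much padding the largest patch forces.

```python
import torch


def build_padded_patches(patches, device=None):
    """Pack variable-size index tensors into a (P, M) tensor plus a mask."""
    max_len = max(len(p) for p in patches)
    padded = torch.zeros(len(patches), max_len, dtype=torch.long, device=device)
    valid = torch.zeros(len(patches), max_len, dtype=torch.bool, device=device)
    for i, p in enumerate(patches):
        padded[i, :len(p)] = p  # unused slots keep index 0
        valid[i, :len(p)] = True
    return padded, valid


def padded_patch_loss(y_pred, y_true, padded, valid):
    b = y_pred.shape[0]
    p, m = padded.shape
    pred = y_pred[:, padded].flatten(0, 1)  # (B*P, M, dim)
    true = y_true[:, padded].flatten(0, 1)
    diff = (torch.cdist(pred, pred) - torch.cdist(true, true)) ** 2
    diff = diff.reshape(b, p, m, m)
    # A pair counts only if both of its slots hold real vertices
    pair_mask = valid[:, :, None] & valid[:, None, :]  # (P, M, M)
    per_sample = (diff * pair_mask).sum(dim=(1, 2, 3)) / 2
    return per_sample.mean()


def loop_patch_loss(y_pred, y_true, patches):
    """Nested-loop reference, for comparison only."""
    total = 0.0
    for i in range(y_pred.shape[0]):
        for patch in patches:
            d = torch.pdist(y_pred[i][patch]) - torch.pdist(y_true[i][patch])
            total = total + (d ** 2).sum()
    return total / y_pred.shape[0]


# Small CPU check that both versions agree
torch.manual_seed(0)
y_pred = torch.rand(2, 40, 2, dtype=torch.double)
y_true = torch.rand(2, 40, 2, dtype=torch.double)
adjacency = torch.rand(40, 40) > 0.85
patches = tuple(adjacency[i].nonzero().squeeze(-1) for i in range(40))
padded, valid = build_padded_patches(patches)
fast = padded_patch_loss(y_pred, y_true, padded, valid)
slow = loop_patch_loss(y_pred, y_true, patches)
print(torch.allclose(fast, slow))
```

The padding step still loops in Python, but it depends only on the patches, so it could be done once per mesh rather than on every loss evaluation.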
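I used double precision in the example only to avoid numerical errors in the results; I intend to use single (float) precision.
Finally, I would be glad to receive any suggestions on improving the speed!
Because of this bug, I have to be careful with memory allocation. I rewrote the vectorized code as: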
def loss_func_vec_mem(y_pred, batch):
    y_true = batch['vertices']
    patches: Tuple[Tensor, ...] = batch['patches']
    loss = 0
    for patch in patches:
        # Gather the patch vertices first, so only small distance matrices are allocated
        pred = y_pred[:, patch, :]
        true = y_true[:, patch, :]
        pred_dists = torch.cdist(pred, pred)
        true_dists = torch.cdist(true, true)
        diff_dists = (pred_dists - true_dists) ** 2
        loss += torch.sum(diff_dists, dim=(1, 2))
    return loss.mean() / 2
However, it is even slower:
LOSS FUNCTION (ITERATION): TIME (LOSS)
Vec mem 0: 0.0484s (14685.45521)
Vec mem 1: 0.0489s (14685.45521)
Vec mem 2: 0.0471s (14685.45521)
Vec mem 3: 0.0471s (14685.45521)
Vec mem 4: 0.0459s (14685.45521)
Vectorized 0: 0.0271s (29370.91043)
Vectorized 1: 0.0262s (29370.91043)
Vectorized 2: 0.0320s (29370.91043)
Vectorized 3: 0.0266s (29370.91043)
Vectorized 4: 0.0263s (29370.91043)
Standard 0: 0.1680s (14685.45521)
Standard 1: 0.1614s (14685.45521)
Standard 2: 0.1639s (14685.45521)
Standard 3: 0.1606s (14685.45521)
Standard 4: 0.1694s (14685.45521)
Reference without patches 0: 0.0049s (12874.01275)
Reference without patches 1: 0.0105s (12874.01275)
Reference without patches 2: 0.0103s (12874.01275)
Reference without patches 3: 0.0092s (12874.01275)
Reference without patches 4: 0.0100s (12874.01275)