是否有更好/更快的方法在Pytorch中实现以下代码,同时又避免了循环并仍保持计算图完整无缺?
def cumulative_max(X, dim=-1):
out = X.clone()
if dim < 0:
dim += X.dim()
leading_indices = (slice(None), ) * dim
n_iters = X.size(dim)
for idx in range(1, n_iters):
out[leading_indices + (idx, )] = torch.max(out[leading_indices + (idx - 1, )], X[leading_indices + (idx, )])
return out