我有以下代码片段可在Pytorch中进行复杂乘法
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
out = torch.stack([a*c - b*d, a*d + b*c], dim=-1)
return out
当我调查为什么得到RuntimeError: CUDA out of memory
时,
我意识到,内存分配取决于我明确给出的指令数量。
我用三种不同的方式编写了该代码,另两种方式如下。
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
real = a*c - b*d
imag = a*d + b*c
pair = [real, imag]
out = torch.stack(pair, dim=-1)
return out
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
r1 = a*c
r2 = b*d
real = r1 - r2
i1 = a*d
i2 = b*c
imag = i1 + i2
pair = [real, imag]
out = torch.stack(pair, dim=-1)
return out
在Variation 1 Gist中,我们有:
float32
,即4 bytes
_x
的大小为(128, 64, 32, 32, 2)
,因此64 MB
,32 MB
每个组成部分(real, imag)
每个分量乘法(a*c
,b*d
,a*d
和b*c
)将有2048 MB
64
乘以32 MB
的每个组成部分的x
8 GB
小计加2 GB
,如果将每个sum
结果视为存储在其自己的变量中
12GB
小计堆叠real
和imag
需要更多的4 GB
16 GB
RuntimeError: CUDA out of memory. Tried to allocate 4.00 GiB (GPU 0; 14.73 GiB total capacity; 12.12 GiB already allocated; 1.82 GiB free; 12.12 GiB reserved in total by PyTorch)
直觉上,我随后决定通过思考坚持使用Variation 3
- “没有这些临时变量(占位符)的持有者,Torch将知道它们将不再被使用,从而释放了它们的内存。因此,在执行结束之前仅使用
4 GB
。”
但是,我不理解的是,这是我的问题,是
为什么在Variation 3上几次调用该方法后仍然得到
CUDA out of memory
?要点表明,在第三个调用中,它由于没有足够的内存而崩溃。
更有趣的是,尽管人们可以思考
- “那是因为已经有
4 GB
分配了先前的结果。”
但是,考虑到它能够执行两次,我认为这不是问题。