我正在尝试实现一种特殊的矩阵乘法,其中门控矩阵乘以权重矩阵。我打算实现的是:
# x.shape = [70*20*1024,1]
# g.shape = [70*20*96*1*1]
# w.shape = [96*128*128]
g = g * weight0 # Cuda out of memory
g = g.view(70,20,1536,1024)
res = g@X
但是问题是当我要将门(g)与权重矩阵(weight0)相乘时,它抛出Cuda out of memory
异常,这是由于Hadamard产品(g*weight0
)中的广播所致。我该如何解决这个问题?我的意思是,如何实现Hadamard产品以使用更少的内存?还是Pytorch中有任何功能可以用更少的内存来实现相同的目的?