Pytorch中优化的门控矩阵乘法

时间:2019-07-08 11:59:06

标签: matrix pytorch matrix-multiplication

我正在尝试实现一种特殊的矩阵乘法,其中门控矩阵乘以权重矩阵。我打算实现的是:

# x.shape = [70*20*1024,1]
# g.shape = [70*20*96*1*1]
# w.shape = [96*128*128]
g = g * weight0 # Cuda out of memory
g = g.view(70,20,1536,1024)
res = g@X

但是问题是当我要将门(g)与权重矩阵(weight0)相乘时,它抛出Cuda out of memory异常,这是由于Hadamard产品(g*weight0)中的广播所致。我该如何解决这个问题?我的意思是,如何实现Hadamard产品以使用更少的内存?还是Pytorch中有任何功能可以用更少的内存来实现相同的目的?

0 个答案:

没有答案