Pytorch:输出w.r.t参数的梯度

时间:2018-05-04 13:06:08

标签: python neural-network pytorch gradient-descent autograd

我对寻找关于参数(权重和偏差)的神经网络输出梯度感兴趣。

更具体地说,假设我有以下神经网络结构[6,4,3,1]。输入样本大小为20.我感兴趣的是找到权重(和偏差)的神经网络输出的梯度,如果我没有弄错,在这种情况下将是47。文献中,这种梯度有时被称为Weight_Jacobian。

我在Jupyter Notebook的Python 3.6上使用Pytorch 0.4.0版。

我制作的代码是:

def init_params(layer_sizes, scale=0.1, rs=npr.RandomState(0)):
    return [(rs.randn(insize, outsize) * scale,   # weight matrix
                 rs.randn(outsize) * scale)           # bias vector
                 for insize, outsize in 
                 zip(layer_sizes[:-1],layer_sizes[1:])]
layers = [6, 4, 3, 1]
w = init_params(layers)
first_layer_w = Variable(torch.tensor(w[0][0],requires_grad=True))
first_layer_bias = Variable(torch.tensor(w[0][1],requires_grad=True))
second_layer_w = Variable(torch.tensor(w[1][0],requires_grad=True))
second_layer_bias = Variable(torch.tensor(w[1][1],requires_grad=True))
third_layer_w = Variable(torch.tensor(w[2][0],requires_grad=True))
third_layer_bias = Variable(torch.tensor(w[2][1],requires_grad=True))
X = Variable(torch.tensor(X_batch),requires_grad=True)
output=torch.tanh(torch.mm(torch.tanh(torch.mm(torch.tanh(torch.mm(X,first_layer_w)+first_layer_bias),second_layer_w)+second_layer_bias),third_layer_w)+third_layer_bias)
output.backward()

从代码中可以明显看出,我使用双曲正切作为非线性。代码生成长度为20的输出向量。现在,我有兴趣找到所有权重(所有47个)的此输出向量的梯度w.r.t。我在here阅读了Pytorch的文档。我也看过类似的问题,例如here。但是,我未能找到输出向量w.r.t参数的梯度。 如果我使用Pytorch函数backward(),它会生成错误

RuntimeError: grad can be implicitly created only for scalar outputs

我的问题是,有没有办法计算输出矢量wrt参数的梯度,它基本上可以表示为20 * 47矩阵,因为我的输出矢量大小为20,参数矢量的大小为47?如果是这样,怎么样?我的代码有什么问题吗?您可以采用X的任何示例,只要其尺寸为20 * 6即可。

2 个答案:

答案 0 :(得分:0)

您正在尝试计算函数的雅可比行列式,而PyTorch希望您能够计算矢量雅可比乘积。您可以看到有关使用PyTorch here计算Jacobian的深入讨论。

您有两个选择。您的第一个选择是使用JAXautograd并使用jacobian()函数。您的第二个选择是坚持使用Pytorch并通过调用backwards(vec) 20次来计算20个向量-jacobian产品,其中vec是长度为20的单热点向量,其中组件的索引为1范围从0到19。如果这令人困惑,建议您阅读JAX教程中的autodiff cookbook

答案 1 :(得分:0)

函数相对于其参数的偏导数矩阵称为 Jacobian,可以在 PyTorch 中计算:

torch.autograd.functional.jacobian(func, inputs)