参数的光谱范数如何计算?

时间:2019-12-01 07:46:00

标签: python pytorch

我这样做的时候

import torch, torch.nn as nn
x = nn.Linear(3, 3)
y = torch.nn.utils.spectral_norm(x)

然后给出四个不同的权重矩阵,

y.weight_u

tensor([ 0.6534, -0.1644,  0.7390])

y.weight_orig

Parameter containing:
tensor([[ 0.2538,  0.3196,  0.3380],
        [ 0.4946,  0.0519,  0.1022],
        [-0.5549, -0.0401,  0.1654]], requires_grad=True)

y.weight_v

tensor([-0.3650,  0.2870,  0.8857])

y.weight

tensor([[ 0.5556,  0.6997,  0.7399],
        [ 1.0827,  0.1137,  0.2237],
        [-1.2149, -0.0878,  0.3622]], grad_fn=<DivBackward0>)

这四个矩阵如何计算?

1 个答案:

答案 0 :(得分:1)

我刚读完有关此方法的论文,可以在on arxiv中找到。如果您具有适当的数学背景,建议您阅读。有关幂算法的描述,请参阅附录A。

那是我在这里尝试总结的地方。

首先,您应该知道矩阵的谱范数是最大奇异值。作者建议找到权重矩阵1.2的频谱范数,然后将b=Math.floor($(this).val()/5;)除以其频谱范数使其接近W(此决策的合理性在本文中)。

虽然我们仅可以使用W来找到奇异值的精确估计,但它们却使用一种称为“幂迭代”的快速(但不精确)的方法。长话短说,1torch.svd是对应于W的最大奇异值的左右奇异向量的粗略近似。它们之所以有用,是因为W的相关奇异值(即谱范数)如果weight_uweight_vu.transpose(1,0) @ W @ v的实际左/右奇异向量,则W等于u

  • v包含图层中的原始值。
  • Wy.weight_orig的第一个左奇异矢量的近似值。
  • y.weight_uy.weight_orig的第一个右奇异矢量的近似值。
  • y.weight_v是更新的权重矩阵,它是y.weight_orig除以其近似频谱范数。

我们可以通过显示实际的左奇异矢量和右奇异矢量几乎平行于y.weighty.weight_orig

来验证这些主张。
y.weight_u

结果

y.weight_v

增加import torch import torch.nn as nn # pytorch default is 1 n_power_iterations = 1 y = nn.Linear(3,3) y = nn.utils.spectral_norm(y, n_power_iterations=n_power_iterations) # spectral normalization is performed during forward pre hook for technical reasons, we # need to send something through the layer to ensure normalization is applied # NOTE: After this is performed, x.weight is changed in place! _ = y(torch.randn(1,3)) # test svd vs. spectral_norm u/v estimates u,s,v = torch.svd(y.weight_orig) cos_err_u = 1.0 - torch.abs(torch.dot(y.weight_u, u[:, 0])).item() cos_err_v = 1.0 - torch.abs(torch.dot(y.weight_v, v[:, 0])).item() print('u-estimate cosine error:', cos_err_u) print('v-estimate cosine error:', cos_err_v) # singular values actual_orig_sn = s[0].item() approx_orig_sn = (y.weight_u @ y.weight_orig @ y.weight_v).item() print('Actual original spectral norm:', actual_orig_sn) print('Approximate original spectral norm:', approx_orig_sn) # updated weights singular values u,s_new,v = torch.svd(y.weight.data, compute_uv=False) actual_sn = s_new[0].item() print('Actual updated spectral norm:', actual_sn) print('Desired updated spectral norm: 1.0') 参数会增加估计的准确性,但会花费计算时间。