ValueError:浮点图像RGB值必须在0..1范围内。使用matplotlib时

时间:2017-11-15 23:09:40

标签: matplotlib deep-learning pytorch

我想要可视化神经网络层的权重。我正在使用pytorch。

import torch
import torchvision.models as models
from matplotlib import pyplot as plt

def plot_kernels(tensor, num_cols=6):
    if not tensor.ndim==4:
        raise Exception("assumes a 4D tensor")
    if not tensor.shape[-1]==3:
        raise Exception("last dim needs to be 3 to plot")
    num_kernels = tensor.shape[0]
    num_rows = 1+ num_kernels // num_cols
    fig = plt.figure(figsize=(num_cols,num_rows))
    for i in range(tensor.shape[0]):
        ax1 = fig.add_subplot(num_rows,num_cols,i+1)
        ax1.imshow(tensor[i])
        ax1.axis('off')
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])

    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
vgg = models.vgg16(pretrained=True)
mm = vgg.double()
filters = mm.modules
body_model = [i for i in mm.children()][0]
layer1 = body_model[0]
tensor = layer1.weight.data.numpy()
plot_kernels(tensor)

上面给出了此错误ValueError: Floating point image RGB values must be in the 0..1 range.

我的问题是我应该规范化并采用权重的绝对值来克服这个错误还是有其他方法吗? 如果我规范化并使用绝对值,我认为图形的含义会发生变化。

[[[[ 0.02240197 -1.22057354 -0.55051649]
   [-0.50310904  0.00891289  0.15427093]
   [ 0.42360783 -0.23392732 -0.56789106]]

  [[ 1.12248898  0.99013627  1.6526649 ]
   [ 1.09936976  2.39608836  1.83921957]
   [ 1.64557672  1.4093554   0.76332706]]

  [[ 0.26969245 -1.2997849  -0.64577204]
   [-1.88377869 -2.0100112  -1.43068039]
   [-0.44531786 -1.67845118 -1.33723605]]]


 [[[ 0.71286005  1.45265901  0.64986968]
   [ 0.75984162  1.8061738   1.06934202]
   [-0.08650422  0.83452386 -0.04468433]]

  [[-1.36591709 -2.01630116 -1.54488969]
   [-1.46221244 -2.5365622  -1.91758668]
   [-0.88827479 -1.59151018 -1.47308767]]

  [[ 0.93600738  0.98174071  1.12213969]
   [ 1.03908169  0.83749604  1.09565806]
   [ 0.71188802  0.85773659  0.86840987]]]


 [[[-0.48592842  0.2971966   1.3365227 ]
   [ 0.47920835 -0.18186836  0.59673625]
   [-0.81358945  1.23862112  0.13635623]]

  [[-0.75361633 -1.074965    0.70477796]
   [ 1.24439156 -1.53563368 -1.03012812]
   [ 0.97597247  0.83084011 -1.81764793]]

  [[-0.80762428 -0.62829626  1.37428832]
   [ 1.01448071 -0.81775147 -0.41943246]
   [ 1.02848887  1.39178836 -1.36779451]]]


 ..., 
 [[[ 1.28134537 -0.00482408  0.71610934]
   [ 0.95264435 -0.09291686 -0.28001019]
   [ 1.34494913  0.64477581  0.96984017]]

  [[-0.34442815 -1.40002513  1.66856039]
   [-2.21281362 -3.24513769 -1.17751861]
   [-0.93520379 -1.99811196  0.72937071]]

  [[ 0.63388056 -0.17022935  2.06905985]
   [-0.7285465  -1.24722099  0.30488953]
   [ 0.24900314 -0.19559766  1.45432627]]]


 [[[-0.80684513  2.1764245  -0.73765725]
   [-1.35886598  1.71875226 -1.73327696]
   [-0.75233924  2.14700699 -0.71064663]]

  [[-0.79627383  2.21598244 -0.57396138]
   [-1.81044972  1.88310981 -1.63758397]
   [-0.6589964   2.013237   -0.48532376]]

  [[-0.3710472   1.4949851  -0.30245575]
   [-1.25448656  1.20453358 -1.29454732]
   [-0.56755757  1.30994892 -0.39370224]]]


 [[[-0.67361742 -3.69201088 -1.23768616]
   [ 3.12674141  1.70414758 -1.76272404]
   [-0.22565465  1.66484773  1.38172317]]

  [[ 0.28095332 -2.03035069  0.69989491]
   [ 1.97936332  1.76992691 -1.09842575]
   [-2.22433758  0.52577412  0.18292744]]

  [[ 0.48471382 -1.1984663   1.57565165]
   [ 1.09911084  1.31910467 -0.51982772]
   [-2.76202297 -0.47073677  0.03936549]]]]

1 个答案:

答案 0 :(得分:3)

听起来好像你已经知道你的价值不在那个范围内。是的,您必须将它们重新缩放到0.0 - 1.0范围内。我建议您希望保持负面与正面的可见度,但是让0.5成为新的“中性”点。缩放使得当前0.0值映射到0.5,并且您的最极值(最大幅度)缩放为0.0(如果为负)或1.0(如果为正)。

感谢矢量。看起来您的值在-2.25到+2.0的范围内。我建议重新调整new = (1/(2*2.25)) * old + 0.5