使用Deep Replay软件包可视化ResNet中的重量初始化

时间:2019-07-15 09:30:25

标签: python matplotlib subplot resnet

我正在尝试使用MNIST数据集和Deep replay来训练ResNet,但是问题是我使用this method来可视化深度神经网络中的权重初始化,并且遇到了以下错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/anaconda3/envs/CR7/lib/python3.6/site-packages/seaborn/utils.py in categorical_order(values, order)
    525             try:
--> 526                 order = values.cat.categories
    527             except (TypeError, AttributeError):

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/generic.py in __getattr__(self, name)
   4371                 name in self._accessors):
-> 4372             return object.__getattribute__(self, name)
   4373         else:

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/accessor.py in __get__(self, obj, cls)
    132             return self._accessor
--> 133         accessor_obj = self._accessor(obj)
    134         # Replace the property with the accessor object. Inspired by:

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/arrays/categorical.py in __init__(self, data)
   2376     def __init__(self, data):
-> 2377         self._validate(data)
   2378         self.categorical = data.values

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/arrays/categorical.py in _validate(data)
   2385         if not is_categorical_dtype(data.dtype):
-> 2386             raise AttributeError("Can only use .cat accessor with a "
   2387                                  "'category' dtype")

AttributeError: Can only use .cat accessor with a 'category' dtype

During handling of the above exception, another exception occurred:

OverflowError                             Traceback (most recent call last)
<ipython-input-26-260535e317aa> in <module>
     12 av = replay.build_outputs(ax_activations, exclude_outputs=True, include_inputs=False)
     13 
---> 14 fig = compose_plots([wv, gv, zv, av], epoch=0, title=r'ResNet34 - 1 epoch')

~/anaconda3/envs/CR7/lib/python3.6/site-packages/deepreplay/plot.py in compose_plots(objects, epoch, title)
    145 
    146     for obj in objects:
--> 147         getattr(obj.__class__, '_update')(epoch, obj)
    148         for ax, ax_title in zip(obj.axes, obj.title):
    149             ax.set_title(ax_title)

~/anaconda3/envs/CR7/lib/python3.6/site-packages/deepreplay/plot.py in _update(i, lv, epoch_start)
    624 
    625         lv.ax.clear()
--> 626         sns.violinplot(data=df, x='layers', y='values', ax=lv.ax, cut=0, palette=lv.palette, scale='width')
    627         lv.ax.set_xticklabels(df.layers.unique())
    628         lv.ax.set_xlabel('Layers')

~/anaconda3/envs/CR7/lib/python3.6/site-packages/seaborn/categorical.py in violinplot(x, y, hue, data, order, hue_order, bw, cut, scale, scale_hue, gridsize, width, inner, split, dodge, orient, linewidth, color, palette, saturation, ax, **kwargs)
   2385                              bw, cut, scale, scale_hue, gridsize,
   2386                              width, inner, split, dodge, orient, linewidth,
-> 2387                              color, palette, saturation)
   2388 
   2389     if ax is None:

~/anaconda3/envs/CR7/lib/python3.6/site-packages/seaborn/categorical.py in __init__(self, x, y, hue, data, order, hue_order, bw, cut, scale, scale_hue, gridsize, width, inner, split, dodge, orient, linewidth, color, palette, saturation)
    560                  color, palette, saturation):
    561 
--> 562         self.establish_variables(x, y, hue, data, orient, order, hue_order)
    563         self.establish_colors(color, palette, saturation)
    564         self.estimate_densities(bw, cut, scale, scale_hue, gridsize)

~/anaconda3/envs/CR7/lib/python3.6/site-packages/seaborn/categorical.py in establish_variables(self, x, y, hue, data, orient, order, hue_order, units)
    201 
    202                 # Get the order on the categorical axis
--> 203                 group_names = categorical_order(groups, order)
    204 
    205                 # Group the numeric data

~/anaconda3/envs/CR7/lib/python3.6/site-packages/seaborn/utils.py in categorical_order(values, order)
    527             except (TypeError, AttributeError):
    528                 try:
--> 529                     order = values.unique()
    530                 except AttributeError:
    531                     order = pd.unique(values)

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/series.py in unique(self)
   1491         Categories (3, object): [a < b < c]
   1492         """
-> 1493         result = super(Series, self).unique()
   1494 
   1495         if is_datetime64tz_dtype(self.dtype):

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/base.py in unique(self)
   1047         else:
   1048             from pandas.core.algorithms import unique1d
-> 1049             result = unique1d(values)
   1050 
   1051         return result

~/anaconda3/envs/CR7/lib/python3.6/site-packages/pandas/core/algorithms.py in unique(values)
    365     htable, _, values, dtype, ndtype = _get_hashtable_algo(values)
    366 
--> 367     table = htable(len(values))
    368     uniques = table.unique(values)
    369     uniques = _reconstruct_data(uniques, dtype, original)

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.StringHashTable.__init__()

OverflowError: value too large to convert to int

我的代码如下:

filename = 'hyperparms_in_action.h5'
group_name = 'part2'

replaydata = ReplayData(X_train, Y_train, filename=filename, group_name=group_name, model=model)

replay = Replay(replay_filename=filename, group_name=group_name)

fig = plt.figure(figsize=(100, 30))
ax_zvalues = plt.subplot2grid((4, 1), (0, 0))
ax_weights = plt.subplot2grid((4, 1), (1, 0))
ax_activations = plt.subplot2grid((4, 1), (2, 0))
ax_gradients = plt.subplot2grid((4, 1), (3, 0))

wv = replay.build_weights(ax_weights)
gv = replay.build_gradients(ax_gradients)
zv = replay.build_outputs(ax_zvalues, before_activation=True, exclude_outputs=True, include_inputs=False)
av = replay.build_outputs(ax_activations, exclude_outputs=True, include_inputs=False)

fig = compose_plots([wv, gv, zv, av], epoch=0, title=r'ResNet34 - 1 epoch')

长时间考虑到最后,抛出一个错误,它打印出该图: img

如何防止该错误并打印深层神经网络图?也许问题在于子图,因为海量数据无法绘制为子图。如何以正确的方式或单独打印图?

0 个答案:

没有答案