我正在使用CNN进行二进制图像分类任务。我想在Keras中使用以下代码查看卷积层的过滤器:
from mpl_toolkits.axes_grid1 import make_axes_locatable
def nice_imshow(ax, data, vmin=None, vmax=None, cmap=None):
"""Wrapper around pl.imshow"""
if cmap is None:
cmap = cm.jet
if vmin is None:
vmin = data.min()
if vmax is None:
vmax = data.max()
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
im = ax.imshow(data, vmin=vmin, vmax=vmax, interpolation='nearest', cmap=cmap)
pl.colorbar(im, cax=cax)
pl.savefig("/home/nd/Results/filter--{}".format(q) + '.jpg')
import numpy.ma as ma
def make_mosaic(imgs, nrows, ncols, border=1):
"""
Given a set of images with all the same shape, makes a
mosaic with nrows and ncols
"""
nimgs = imgs.shape[0]
imshape = imgs.shape[1:]
mosaic = ma.masked_all((nrows * imshape[0] + (nrows - 1) * border,
ncols * imshape[1] + (ncols - 1) * border),
dtype=np.float32)
paddedh = imshape[0] + border
paddedw = imshape[1] + border
for i in range(nimgs):
row = int(np.floor(i / ncols))
col = i % ncols
mosaic[row * paddedh:row * paddedh + imshape[0],
col * paddedw:col * paddedw + imshape[1]] = imgs[i]
return mosaic
# Visualize weights
filternumber=[1]
for q in filternumber:
W=model.layers[q].get_weights()[0][:,:,0,:]
W=np.swapaxes(W,0,2)
W = np.squeeze(W)
print("W shape : ", W.shape)
pl.figure(figsize=(15, 15))
pl.title('conv1 weights')
nice_imshow(pl.gca(), make_mosaic(W, 16, 16), cmap=cm.binary)
并出现以下错误:
Traceback (most recent call last):
File "<ipython-input-35-a2febeda1dfc>", line 51, in <module>
nice_imshow(pl.gca(), make_mosaic(W, 16, 16), cmap=cm.binary)
File "<ipython-input-35-a2febeda1dfc>", line 26, in make_mosaic
mosaic = ma.masked_all((nrows * imshape[0] + (nrows - 1) * border,
IndexError: tuple index out of range
我的模型摘要如下:
input_1(InputLayer)(无,30、30、1)0
conv2d_1(Conv2D)(无,30、30、64)128 input_1 [0] [0]
conv2d_3(Conv2D)(无,30、30、64)128 input_1 [0] [0]
max_pooling2d_1(MaxPooling2D)(无,30、30、1)0 input_1 [0] [0]
conv2d_2(Conv2D)(无,30、30、64)36928 conv2d_1 [0] [0]
conv2d_4(Conv2D)(无,30、30、64)102464 conv2d_3 [0] [0]
conv2d_5(Conv2D)(无,30、30、64)128 max_pooling2d_1 [0] [0]
concatenate_1(连接)(无,30、30、192)0 conv2d_2 [0] [0]
conv2d_4 [0] [0]
conv2d_5 [0] [0]
flatten_1(扁平)(无,172800)0 concatenate_1 [0] [0]
dense_1(密集)(无,2)345602 flatten_1 [0] [0]
如何解决问题。我的img尺寸是30 * 30。