我正在使用keras,我想实现一个改进的卷积层。我目前正在尝试使用im2col
函数:
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
# First figure out what the size of the output should be
N, H, W, C = x_shape
assert (H + 2 * padding - field_height) % stride == 0
assert (W + 2 * padding - field_height) % stride == 0
out_height = (H + 2 * padding - field_height) // stride + 1
out_width = (W + 2 * padding - field_width) // stride + 1
i0 = np.repeat(np.arange(field_height), field_width)
i0 = np.tile(i0, C)
i1 = stride * np.repeat(np.arange(out_height), out_width)
j0 = np.tile(np.arange(field_width), field_height * C)
j1 = stride * np.tile(np.arange(out_width), out_height)
i = i0.reshape(-1, 1) + i1.reshape(1, -1)
j = j0.reshape(-1, 1) + j1.reshape(1, -1)
k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)
return (k, i, j)
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
""" An implementation of im2col based on some fancy indexing """
# Zero-pad the input
p = padding
x_padded = np.pad(x, ((0, 0), (p, p), (p, p), (0, 0)), mode='constant')
k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding,
stride)
cols = x_padded[:, i, j, k]
C = x.shape[-1]
cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1)
return cols
我不知道如何在keras后端中实现cols = x_padded[:, i, j, k]
。我尝试使用keras.backend.gather
和tf.gather_nd
,但没有结果。有没有一种方法可以实现这种索引编制而不丢失向量化?谢谢。