具有输入内核和偏差的自定义图层

时间:2018-05-01 02:14:21

标签: keras layer

在实现带有输入内核和偏差的自定义conv2d层时遇到问题。该内核和偏差是另一层A的输出,然后使用这些权重来执行conv2d。我只是希望这个图层使用重量而不是学习,所以如果这个图层不可训练,那么渐变会转移到A层,这意味着我希望A层可以训练

1 个答案:

答案 0 :(得分:1)

您不需要使用不可学习的参数编写此类图层。如果我理解正确,你需要以下内容。

import keras
from keras.layers import Conv2D, Dense, Input, Flatten, Lambda
from keras.models import Model
from keras import backend as K

img = Input(shape=(32,32,3), name='img_in')
# this is the standard way of calling a learnable conv kernel and bias
# but kernel and bias are independent of input
x = Conv2D( 64,(5,5),padding='same',name='StandardConv')(img)
# this is a custom way of calling a learnable conv kernel and bias
# but this time they are dependent on the input
img_flat = Flatten(name='flat')(img)
conv_kernel = Dense( 3*5*5*64, name='regConvKernel' )( img_flat )
conv_bias = Dense( 64, name='regConvBias' )( img_flat )
# of course, you need to use conv_kernel and conv_bias to apply conv operation
# and this happens here
def custom_conv( input_vars ) :
    x, kernel, bias = input_vars
    kernel = K.reshape( kernel, (5,5,3,64))
    bias = K.reshape( bias, [1,1,1,64])
    x = K.conv2d( x, kernel, padding='same' )
    x += bias
    return x
def custom_conv_shape( input_shapes ) :
    x_shape, kernel_shape, bias_shape = input_shapes
    return x_shape[:3] + bias_shape[-1:]
y = Lambda( custom_conv, output_shape=custom_conv_shape, name='CustomConv')([img, conv_kernel, conv_bias])
# define your final model
model = Model( inputs=img, outputs=[x,y], name='compareConv')

print model.summary()

# test use dummy numpy arrays
import numpy as np 
a = np.random.randn(1,32,32,3)
b, c = model.predict(a)
print "standard conv output shape =", b.shape
print "custom conv output shape =", c.shape

你会看到如下的输出。

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
img_in (InputLayer)             (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
flat (Flatten)                  (None, 3072)         0           img_in[0][0]                     
__________________________________________________________________________________________________
regConvKernel (Dense)           (None, 4800)         14750400    flat[0][0]                       
__________________________________________________________________________________________________
regConvBias (Dense)             (None, 64)           196672      flat[0][0]                       
__________________________________________________________________________________________________
StandardConv (Conv2D)           (None, 32, 32, 64)   4864        img_in[0][0]                     
__________________________________________________________________________________________________
CustomConv (Lambda)             (None, 32, 32, 64)   0           img_in[0][0]                     
                                                                 regConvKernel[0][0]              
                                                                 regConvBias[0][0]                
==================================================================================================
Total params: 14,951,936
Trainable params: 14,951,936
Non-trainable params: 0
__________________________________________________________________________________________________
None
standard conv output shape = (1, 32, 32, 64)
custom conv output shape = (1, 32, 32, 64)

当然,您可以使用不同的内核大小或填充方案。您可以考虑更合理的方式来估算conv_kernelconv_bias,而不是直接从输入中回归它们。