keras在训练过程中更改参数

时间:2018-07-26 16:46:29

标签: machine-learning keras

我有一个定制的层来执行简单的线性变换。像x*w+b。我想在训练期间更改w and b,可以吗?例如,我想在第一次迭代中使用w1,在第二次迭代中使用w2。(w1w2由我自己定义)。

1 个答案:

答案 0 :(得分:0)

当然可以,但是您需要以一种聪明的方式来做到。这是您可以使用的一些代码。

from keras import backend as K
from keras.layers import *
from keras.models import *
import numpy as np 

class MyDense( Layer ) :
    def __init__( self, units=64, use_bias=True, **kwargs ) :
        super(MyDense, self).__init__( **kwargs )
        self.units = units
        self.use_bias = use_bias
        return 
    def build( self, input_shape ) :
        input_dim = input_shape[-1]
        self.count = 0
        self.w1 = self.add_weight(shape=(input_dim, self.units), initializer='glorot_uniform', name='w1')
        self.w0 = self.add_weight(shape=(input_dim, self.units), initializer='glorot_uniform', name='w0')
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),initializer='glorot_uniform',name='bias' )
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
        return
    def call( self, x ) :
        if self.count % 2 == 1 :
            c0, c1 = 0, 1
        else :
            c0, c1 = 1, 0
        w = c0 * self.w0 + c1 * self.w1
        self.count += 1
        output = K.dot( x, w )
        if self.use_bias:
            output = K.bias_add(output, self.bias, data_format='channels_last')
        return output
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)

# define a dummy model
x = Input(shape=(128,))
y = MyDense(10)(x)
y = Dense(1, activation='sigmoid')(y)
model = Model(inputs=x, outputs=y)
print model.summary()

# get some dummy data
a = np.random.randn(100,128)
b = (np.random.randn(100,) > 0).astype('int32')

# compile and train
model.compile('adam', 'binary_crossentropy')
model.fit( a, b )

请注意:以下代码与我们上面的操作等效,但无法正常工作!

if self.count % 2 == 1 :
    w = self.w0
else :
    w = self.w1

为什么?因为一个变量具有zero梯度(前一种实现)并不等同于具有None梯度(后一种实现)。