如何使用可训练的参数自定义损失函数?

时间:2019-06-19 15:03:04

标签: keras keras-layer

我想自定义下面这个带有可训练参数的损失函数,其中 λ 是正则化系数,w 是可训练参数。我编写了一个自定义损失层,并通过 add_loss 把损失添加到模型中,但训练时出现了下面的错误。我不知道问题出在哪里,我的自定义损失层写得对吗?

损失函数:$$\mathrm{loss} = \mathrm{binary\_crossentropy}(y_{\mathrm{true}},\, y_{\mathrm{pred}}) + \lambda\,|w|$$

1. 自定义损失层

class BCLoss(Layer):
    """Loss layer: binary cross-entropy plus a penalty on a trainable scalar.

    Called on ``[y_true, y_pred]``, the layer returns

        mean(binary_crossentropy(y_true, y_pred), axis=-1) + lamada * |gamma|

    where ``gamma`` is a trainable weight of shape ``(1,)`` created in
    ``build``.

    Args:
        gamma_init: Initializer identifier, resolved with
            ``initializations.get`` and kept on ``self.gamma_init``.
        lamada: Coefficient that multiplies ``|gamma|`` in the loss.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """

    def __init__(self, gamma_init='zero', lamada=0.00005, **kwargs):
        # NOTE(review): `initializations` looks like the Keras 1 module name
        # (Keras 2 calls it `initializers`) -- confirm the file's imports
        # actually provide this name.
        self.gamma_init = initializations.get(gamma_init)
        self.lamada = lamada
        super(BCLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable scalar ``gamma``."""
        # The weight gets its own attribute so that `self.gamma_init` keeps
        # holding the initializer resolved in __init__ instead of being
        # overwritten by the weight tensor.
        # NOTE(review): the weight is still created with the hard-coded
        # 'uniform' initializer, exactly as before; `gamma_init` is not used
        # here -- confirm which initializer is intended.
        self.gamma = self.add_weight(name='gamma',
                                     shape=(1,),
                                     initializer='uniform',
                                     trainable=True)
        super(BCLoss, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Return the cross-entropy averaged over the last axis plus the penalty."""
        y_true, y_pred = inputs
        cross_entropy = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
        penalty = self.lamada * K.abs(self.gamma)
        return cross_entropy + penalty

    def compute_output_shape(self, input_shape):
        """Return the shape of the single loss tensor produced by ``call``."""
        # `input_shape` holds one shape per input ([y_true, y_pred]), but
        # `call` returns one tensor whose last axis has been averaged away,
        # so the prediction shape without its last axis is reported.
        # Assumes inputs of rank >= 2 (batch, ..., features) -- TODO confirm.
        _, y_pred_shape = input_shape
        return tuple(y_pred_shape[:-1])

2. 构建模型

    def build_model(self):
        """Build the 'densenet' model and attach the ``BCLoss`` term to it.

        The network takes two inputs, the image (``img_input``) and the label
        (``y_true``), and outputs the sigmoid prediction ``y_pred``. The loss
        tensor produced by ``BCLoss()([y_true, y_pred])`` is registered with
        ``model.add_loss``.

        Because the labels enter through the ``y_true`` input, a model built
        this way and compiled with ``loss=None`` expects no separate target
        array during training; passing one raises
        "Error when checking model target: expected no data, but got: ...".

        Returns:
            The ``keras.Model`` named ``'densenet'`` with inputs
            ``[img_input, y_true]`` and output ``y_pred``.
        """
        img_input = layers.Input(shape=self.input_shape, name='img_input')
        # NOTE(review): y_true has width 1 while y_pred has width
        # self.nb_classes -- confirm they are meant to broadcast.
        y_true = layers.Input(shape=(1,), name='y_true')
        nb_channels = self.growth_rate

        # Initial convolution layer
        layer_1 = layers.Convolution2D(2 * self.growth_rate, (1, 1), strides=(2, 2),
                                       kernel_regularizer=keras.regularizers.l2(self.weight_decay))(img_input)

        # Pooled copy of the first feature map; it seeds the skip list below.
        average_1 = layers.AveragePooling2D((2, 2), strides=(4, 4))(layer_1)

        x = layers.BatchNormalization()(layer_1)
        x = layers.Activation('relu')(x)

        x = layers.MaxPooling2D()(x)

        # Building dense blocks
        skip_layer = [average_1]

        for block in range(self.dense_blocks - 1):
            # Add dense block
            x, nb_channels = self.dense_block(x, self.dense_layers[block], nb_channels, self.growth_rate,
                                              self.dropout_rate, self.bottleneck, self.weight_decay)

            # Add transition_block: the left output continues down the main
            # path, the right output is collected for the final dense block.
            x_left, x_right = self.transition_layer(x, nb_channels, self.dropout_rate, self.compression,
                                                    self.weight_decay)
            x = x_left
            nb_channels = int(nb_channels * self.compression)
            skip_layer.append(x_right)

        # Add last dense block without transition but for that with global average pooling
        # NOTE(review): unlike the call inside the loop, this call does not
        # pass self.bottleneck, so self.weight_decay lands in the positional
        # slot the loop uses for bottleneck -- verify against dense_block's
        # signature.
        x, nb_channels = self.dense_block(x, self.dense_layers[-1], nb_channels,
                                          self.growth_rate, self.dropout_rate, self.weight_decay,
                                          skip_layer=skip_layer)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dropout(0.2)(x)
        y_pred = layers.Dense(self.nb_classes, activation='sigmoid')(x)

        model = keras.Model(inputs=[img_input, y_true], outputs=y_pred, name='densenet')

        # NOTE(review): the BCLoss layer is applied after the model is created
        # and is not on the path from the inputs to `y_pred`; confirm that its
        # `gamma` weight is actually among the model's trainable weights.
        loss = BCLoss()([y_true, y_pred])
        model.add_loss(loss)
        return model

3. 编译模型

# Compile without a compile-time loss (loss=None); presumably the loss comes
# from the tensor registered on the model via add_loss -- confirm.
model.compile(
    loss=None,
    optimizer=Adam(lr=0.001),
    metrics=['accuracy'],
)

4. 错误信息

Traceback (most recent call last):
File "G:/Workspace/workspace_python/Lung_EGRF/train.py", line 109, in <module>
train_model(config['training_data'],config['testing_data'], config['model_file'],config['input_shape'])
File "G:/Workspace/workspace_python/Lung_EGRF/train.py", line 97, in train_model
early_stopping_patience=50))
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\engine\training.py", line 1418, in fit_generator
initial_epoch=initial_epoch)
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\engine\training_generator.py", line 217, in fit_generator
class_weight=class_weight)
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\engine\training.py", line 1211, in train_on_batch
class_weight=class_weight)
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\engine\training.py", line 789, in _standardize_user_data
exception_prefix='target')
File "D:\Anaconda3\envs\py36\lib\site-packages\keras\engine\training_utils.py", line 63, in standardize_input_data
'expected no data, but got:', data)
ValueError: ('Error when checking model target: expected no data, but got:', 
array([0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0,
   1, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int64))

 Process finished with exit code 1

0 个答案:

没有答案