How do I apply a filter in a custom TensorFlow loss function?

Time: 2020-10-19 02:19:50

Tags: python tensorflow keras loss-function

I know we can create a custom loss function like this:

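import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

def custom_loss(y_true, y_pred):
    y_pred = K.round(y_pred / 1000) * 1000  # round predictions to the nearest 1000
    loss = tf.keras.losses.MSE(y_true, y_pred)
    return K.sqrt(loss)

# feature_layer and alpha are defined elsewhere (not shown in this snippet)
opt = tf.keras.optimizers.Adam(learning_rate=alpha)  # must exist before compile()

model = tf.keras.Sequential()
model.add(feature_layer)
model.add(layers.Dense(1, activation="relu"))
model.compile(loss='mse', optimizer=opt,
              metrics=[tf.keras.metrics.RootMeanSquaredError(), custom_loss])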

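However, I don't know how to apply a filter inside the custom loss function (it seems to support only Keras backend functions).

As an example of a filter: compute the loss only where y_true >= 1000.

Any suggestions? I want to monitor the filtered custom loss during training.

Thanks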

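1 answer:

Answer 0: (score: 0)

You can do this with a combination of tensorflow and tensorflow.keras.backend functions. The example below uses toy data and a threshold of 5 instead of 1000: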

import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as kb

# Toy data: x = 1..10 and y = 2x, arranged as 2 samples of 5 values each
x = np.arange(1, 11, dtype=np.float32).reshape(2, 5)
y = 2 * x


def custom_loss(y_true, y_pred):
    # Keep the error only where the condition holds; everything else becomes 0
    diff = tf.where(y_true >= 5, y_true - y_pred, 0)
    sum_of_squares = kb.sum(kb.square(diff), axis=-1)
    # Count the non-zero differences (a filtered element with zero error is not counted)
    value_counts = kb.sum(tf.where(diff != 0, 1, 0), axis=-1)
    value_counts = tf.cast(value_counts, sum_of_squares.dtype)
    # RMSE over the counted elements; NaN if no element passes the filter
    return kb.sqrt(sum_of_squares / value_counts)


tf.random.set_seed(52)
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(10, activation='relu'),
    tf.keras.layers.Dense(1)
])
tf.keras.backend.clear_session()
model.compile(loss='mse', metrics=[tf.keras.metrics.RootMeanSquaredError(), custom_loss])
model.fit(x, y, epochs=10)

Epoch 1/10
1/1 [==============================] - 0s 996us/step - loss: 277.5225 - root_mean_squared_error: 16.6590 - custom_loss: 16.3645
Epoch 2/10
1/1 [==============================] - 0s 997us/step - loss: 264.7605 - root_mean_squared_error: 16.2715 - custom_loss: 16.0073
Epoch 3/10
1/1 [==============================] - 0s 998us/step - loss: 255.8174 - root_mean_squared_error: 15.9943 - custom_loss: 15.7518
Epoch 4/10
1/1 [==============================] - 0s 996us/step - loss: 248.5119 - root_mean_squared_error: 15.7643 - custom_loss: 15.5396
Epoch 5/10
1/1 [==============================] - 0s 998us/step - loss: 242.1583 - root_mean_squared_error: 15.5614 - custom_loss: 15.3526
Epoch 6/10
1/1 [==============================] - 0s 998us/step - loss: 236.4406 - root_mean_squared_error: 15.3766 - custom_loss: 15.1821
Epoch 7/10
1/1 [==============================] - 0s 0s/step - loss: 231.1830 - root_mean_squared_error: 15.2047 - custom_loss: 15.0235
Epoch 8/10
1/1 [==============================] - 0s 997us/step - loss: 226.2768 - root_mean_squared_error: 15.0425 - custom_loss: 14.8739
Epoch 9/10
1/1 [==============================] - 0s 2ms/step - loss: 221.6491 - root_mean_squared_error: 14.8879 - custom_loss: 14.7312
Epoch 10/10
1/1 [==============================] - 0s 999us/step - loss: 217.2489 - root_mean_squared_error: 14.7394 - custom_loss: 14.5941
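
One caveat with the function above: it counts non-zero differences instead of counting the elements that pass the filter, so it returns NaN for a batch where nothing passes. A possible variation is sketched below using the question's threshold of 1000. The names filtered_rmse and THRESHOLD are hypothetical, and the sketch assumes y_true and y_pred have the same shape and that one value per batch (not per sample) is wanted.

```python
import tensorflow as tf

THRESHOLD = 1000.0  # assumed cutoff, taken from the question's example


def filtered_rmse(y_true, y_pred):
    """RMSE over the elements where y_true >= THRESHOLD (0 if none qualify)."""
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    # 1.0 where the element passes the filter, 0.0 elsewhere
    mask = tf.cast(y_true >= THRESHOLD, y_pred.dtype)
    squared_error = tf.square(y_true - y_pred) * mask
    # divide_no_nan yields 0 instead of NaN when no element passes the filter
    mse = tf.math.divide_no_nan(tf.reduce_sum(squared_error), tf.reduce_sum(mask))
    return tf.sqrt(mse)


# Only the last two rows pass the filter: errors are -100 and 300
y_true = tf.constant([[500.0], [1000.0], [3000.0]])
y_pred = tf.constant([[0.0], [1100.0], [2700.0]])
print(float(filtered_rmse(y_true, y_pred)))  # sqrt((100**2 + 300**2) / 2), about 223.61
```

It can be passed in the same way as before, for example metrics=[tf.keras.metrics.RootMeanSquaredError(), filtered_rmse]. It is meant for monitoring: if it were used as the training loss, a batch with no qualifying elements would contribute no useful gradient.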