我知道可以像下面这样创建自定义损失函数:
def custom_loss(y_true, y_pred):
    """Return the square root of the MSE after rounding predictions to 1000s.

    Predictions are divided by 1000, rounded with ``K.round`` and scaled
    back, so they are compared to ``y_true`` as multiples of 1000.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor.

    Returns:
        ``K.sqrt`` applied to ``tf.keras.losses.MSE(y_true, rounded_pred)``.
    """
    unit = 1000  # rounding granularity
    snapped_pred = K.round(y_pred / unit) * unit
    return K.sqrt(tf.keras.losses.MSE(y_true, snapped_pred))
# Create the optimizer before compiling: `model.compile` below receives
# `opt`, so it has to exist at that point (assigning it afterwards raises
# NameError when the snippet runs top to bottom).
opt = tf.keras.optimizers.Adam(learning_rate=alpha)

# Feature layer followed by a single ReLU-activated output unit.
model = tf.keras.Sequential()
model.add(feature_layer)
model.add(layers.Dense(1, activation="relu"))

# Train on plain MSE; RMSE and the custom rounded loss are only reported
# as metrics during training.
model.compile(
    loss='mse',
    optimizer=opt,
    metrics=[tf.keras.metrics.RootMeanSquaredError(), custom_loss],
)
但是,我不知道如何在自定义损失函数中使用过滤条件(因为它看起来仅支持 Keras backend 函数)。
例如,过滤条件可以是:仅在 y_true >= 1000 时计算损失。
有什么建议吗?我想在训练过程中监视过滤后的自定义损失函数。
谢谢
答案 0 :(得分:0)
您可以结合使用 tensorflow 和 tensorflow.keras.backend 中的方法来实现这一目标:
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as kb
# Toy regression data: inputs 1..10 arranged as 2 rows of 5 values, with
# targets y = 2 * x. Both arrays are float32.
x = np.arange(1, 11, dtype=np.float32).reshape(2, 5)
y = 2 * x
def custom_loss(y_true, y_pred, threshold=5):
    """Return the RMSE computed only over targets at or above ``threshold``.

    Elements where ``y_true < threshold`` are excluded: their error is
    replaced by zero and they are not counted in the denominator. The
    reduction is taken over the last axis.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor (broadcast against ``y_true``).
        threshold: Minimum target value for an element to be included.
            Defaults to 5, the value previously hard-coded here.

    Returns:
        Tensor holding, for each slice along the last axis, the square root
        of the mean squared error over the selected elements. A slice with
        no selected elements yields 0 instead of NaN.
    """
    mask = y_true >= threshold
    error = y_true - y_pred

    # Zero out the error wherever the target is filtered out.
    masked_error = tf.where(mask, error, tf.zeros_like(error))
    sum_of_squares = kb.sum(kb.square(masked_error), axis=-1)

    # Count the selected elements from the mask itself. Counting
    # `diff != 0` would drop selected elements whose prediction is exact
    # and so inflate the reported error.
    selected = tf.where(mask, tf.ones_like(error), tf.zeros_like(error))
    value_counts = kb.sum(selected, axis=-1)

    # divide_no_nan returns 0 when nothing was selected (0 / 0), so one
    # empty slice cannot turn the monitored metric into NaN.
    mean_square = tf.math.divide_no_nan(sum_of_squares, value_counts)
    return kb.sqrt(mean_square)
# Seed TensorFlow's random number generator before the layers are created.
tf.random.set_seed(52)

# Small regressor: a 10-unit ReLU hidden layer and a single linear output.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(10, activation='relu'))
model.add(tf.keras.layers.Dense(1))

tf.keras.backend.clear_session()

# Optimise plain MSE; RMSE and the filtered custom loss are passed as
# metrics, so they are reported during training without driving it.
model.compile(
    loss='mse',
    metrics=[
        tf.keras.metrics.RootMeanSquaredError(),
        custom_loss,
    ],
)
model.fit(x, y, epochs=10)
Epoch 1/10
1/1 [==============================] - 0s 996us/step - loss: 277.5225 - root_mean_squared_error: 16.6590 - custom_loss: 16.3645
Epoch 2/10
1/1 [==============================] - 0s 997us/step - loss: 264.7605 - root_mean_squared_error: 16.2715 - custom_loss: 16.0073
Epoch 3/10
1/1 [==============================] - 0s 998us/step - loss: 255.8174 - root_mean_squared_error: 15.9943 - custom_loss: 15.7518
Epoch 4/10
1/1 [==============================] - 0s 996us/step - loss: 248.5119 - root_mean_squared_error: 15.7643 - custom_loss: 15.5396
Epoch 5/10
1/1 [==============================] - 0s 998us/step - loss: 242.1583 - root_mean_squared_error: 15.5614 - custom_loss: 15.3526
Epoch 6/10
1/1 [==============================] - 0s 998us/step - loss: 236.4406 - root_mean_squared_error: 15.3766 - custom_loss: 15.1821
Epoch 7/10
1/1 [==============================] - 0s 0s/step - loss: 231.1830 - root_mean_squared_error: 15.2047 - custom_loss: 15.0235
Epoch 8/10
1/1 [==============================] - 0s 997us/step - loss: 226.2768 - root_mean_squared_error: 15.0425 - custom_loss: 14.8739
Epoch 9/10
1/1 [==============================] - 0s 2ms/step - loss: 221.6491 - root_mean_squared_error: 14.8879 - custom_loss: 14.7312
Epoch 10/10
1/1 [==============================] - 0s 999us/step - loss: 217.2489 - root_mean_squared_error: 14.7394 - custom_loss: 14.5941