我正在尝试在Keras的自定义损失函数中找到零的数量。
def root_mean_squared_error(y_true, y_pred):
在此处输入此损失函数的地方:
model.compile(optimizer=sgd,
loss=root_mean_squared_error,
metrics=[metrics.mse, root_mean_squared_error])
我正在尝试在数组“ y_true”中查找非零值的数量,然后将我的数字除以该值。
如何找到y_true中非零元素的数量?
答案 0 :(得分:3)
您可以通过Keras后端使用tf.count_nonzero API。
from keras import backend as K
import numpy as np
def custom_loss(y_true, y_pred):
return y_pred / K.cast(K.tf.count_nonzero(y_true), K.tf.float32)
y_t = K.placeholder((1,2))
y_p = K.placeholder((1,2))
loss = custom_loss(y_t, y_p)
print(K.get_session().run(loss, {y_t: np.array([[1,1]]), y_p: np.array([[2,4]])}))
[[1。 2。]]
答案 1 :(得分:1)
也许您可以对NumPy数组使用布尔条件y_true != 0
:
z = np.array( y_true != 0 )
# Check the shape of z array.
print( z.shape )
count = z.shape[ 0 ]
在这里,计数应为y_true != 0
条件为真的元素的数量。