如何在Keras数组中查找非零数

时间:2019-04-21 01:43:41

标签: python tensorflow keras

我正在尝试在Keras的自定义损失函数中找到零的数量。

def root_mean_squared_error(y_true, y_pred):

在此处输入此损失函数的地方:

model.compile(optimizer=sgd,
          loss=root_mean_squared_error,
          metrics=[metrics.mse, root_mean_squared_error])

我正在尝试在数组“ y_true”中查找非零值的数量,然后将我的数字除以该值。

如何找到y_true中非零元素的数量?

2 个答案:

答案 0 :(得分:3)

您可以通过Keras后端使用tf.count_nonzero API。

from keras import backend as K
import numpy as np

def custom_loss(y_true, y_pred):
    return y_pred / K.cast(K.tf.count_nonzero(y_true), K.tf.float32)

y_t = K.placeholder((1,2))
y_p = K.placeholder((1,2))
loss = custom_loss(y_t, y_p)
print(K.get_session().run(loss, {y_t: np.array([[1,1]]), y_p: np.array([[2,4]])}))

[[1。 2。]]

答案 1 :(得分:1)

也许您可以对NumPy数组使用布尔条件y_true != 0

z = np.array( y_true != 0 )
# Check the shape of z array.
print( z.shape )
count = z.shape[ 0 ]

在这里,计数应为y_true != 0条件为真的元素的数量。