在keras中,如何应用过滤器(where)功能?

时间:2018-04-10 13:59:26

标签: python tensorflow deep-learning keras

像functools包中的过滤功能一样,我希望在张量中找到超过0.5的元素。

这是代码,但不起作用。

def pred_overhalf(y_true, y_pred):

    return K.count_params( filter( lambda x : x > 0.5 , y_pred ) )
model.compile(optimizer = "adam" , loss = "mse", metrics = [ pred_overhalf])

有什么方法可以解决这个问题吗?我搜索keras后端文档,但我找不到任何解决方案

1 个答案:

答案 0 :(得分:0)

def pred_overhalf(y_true,y_pred):
    out = K.greater(y_pred,0.5)
    out = K.cast(out,K.floatx())

    #option 1
    return K.mean(out) #fraction of items greater than 0.5

    #option 2
    return K.sum(out) #total count (beware: this will consider all samples)