tf.keras.backend替换张量值(小于1)的方法

时间:2018-07-05 15:24:54

标签: python tensorflow machine-learning keras tensor

我在Tensorflow后端上使用Keras。

在损失函数中,我有一个张量,需要用1替换小于1的元素。

我可以在文档中看到很多可用的功能 https://www.tensorflow.org/api_docs/python/tf/keras/backend

但是我不确定该怎么做。

如果我这样做:

a_ = tf.Print(
    message='a_shape',
    input_=a_,
    data=[tf.shape(a_)]
)

我得到的形状是:

y_shape[128]

我基本上需要遍历此张量,用1替换小于1的元素。

我如何使用keras tensorflow API做到这一点?

谢谢-

2 个答案:

答案 0 :(得分:1)

如果a是您的张量,则可以执行以下操作:

b = a*tf.cast(a>1, 'float32') + tf.cast(a<=1, 'float32')

答案 1 :(得分:1)

适用于所有后端的“ keras”答案:

isGreater = K.cast(K.greater(a_,1),K.floatx())
result = (a_*isGreater) + (1 - isGreater)