我在Tensorflow后端上使用Keras。
在损失函数中,我有一个张量,需要用1替换小于1的元素。
我可以在文档中看到很多可用的功能 https://www.tensorflow.org/api_docs/python/tf/keras/backend
但是我不确定该怎么做。
如果我这样做:
a_ = tf.Print(
message='a_shape',
input_=a_,
data=[tf.shape(a_)]
)
我得到的形状是:
y_shape[128]
我基本上需要遍历此张量,用1替换小于1的元素。
我如何使用keras tensorflow API做到这一点?
谢谢-
答案 0 :(得分:1)
如果a
是您的张量,则可以执行以下操作:
b = a*tf.cast(a>1, 'float32') + tf.cast(a<=1, 'float32')
答案 1 :(得分:1)
适用于所有后端的“ keras”答案:
isGreater = K.cast(K.greater(a_,1),K.floatx())
result = (a_*isGreater) + (1 - isGreater)