基于输入的输出子集的自定义损失函数

时间:2019-06-27 13:11:37

标签: python tensorflow keras

我想创建一个损失函数,其中仅对输出的子集计算MSE。子集取决于输入数据。我用这个问题的答案来弄清楚如何根据输入数据创建自定义函数:

Custom loss function in Keras based on the input data

但是,我在实现自定义功能时遇到了麻烦。

这是我整理的内容。

def custom_loss(input_tensor):


    def loss(y_true, y_pred):
        board = input_tensor[:81]
        answer_vector = board == .5
        #assert np.sum(answer_vector) > 0

        return K.mean(K.square(y_pred * answer_vector - y_true), axis=-1)
    return loss


def build_model(input_size, output_size):
    learning_rate = .001
    a = Input(shape=(input_size,))
    b = Dense(60, activation='relu')(a)
    b = Dense(60, activation='relu')(b)
    b = Dense(60, activation='relu')(b)
    b = Dense(output_size, activation='linear')(b)
    model = Model(inputs=a, outputs=b)
    model.compile(loss=custom_loss(a), optimizer=Adam(lr=learning_rate))

    return model

model = build_model(83, 81)

我希望MSE在电路板不等于0.5的任何地方都将输出视为0。 (真实值是一个热编码的编码,其中一个在子集中)。由于某种原因,我的输出总是被视为零。换句话说,自定义损失函数似乎并没有找到木板等于0.5的任何地方。

我无法确定我是在错误地解释尺寸还是由于张量而导致比较失败,或者是否只有通常更简单的方法来执行我正在尝试的操作。

1 个答案:

答案 0 :(得分:0)

问题在于answer_vector = board == .5不是您认为的那样。它不是张量,而是布尔值False,因为board是张量,而0.5是数字:

a = tf.constant([0.5, 0.5])
print(a == 0.5) # False

现在,a * False是零的向量:

with tf.Session() as sess:
   print(sess.run(a * False)) # [0.0, 0.0]

您需要使用tf.equal而不是==。另一个可能的陷阱是比较浮点数与相等是危险的,请参见例如What's wrong with using == to compare floats in Java?