Keras自定义损失:我如何知道与y_pred和y_true对应的模式?

时间:2020-01-16 09:54:15

标签: tensorflow keras deep-learning

在keras中,您可以使用参数(y_true, y_pred)定义自定义损失。 我怎么知道它们与哪些模式相关? 我的意思是,y_true是具有batchSize元素的张量。如何将这些元素与原始X相关联? 我想知道y_true[0]和相对的X[i]之间的对应关系。

1 个答案:

答案 0 :(得分:1)

所以您想要的是这样的损失函数

def custom_loss(y_true, y_pred, X):

因为您需要输入来进行损失计算。 据我所知,在Keras中这是不可能直接实现的。

一种可能的解决方法是拥有正在运行的索引:

X = ...
Y = ...
batch_size = ...
i = 0
def custom_loss(y_true, y_pred):
    x = X[i*batch_size:(i+1)*batch_size]
    loss = ...
    i += 1
    return loss

请确保在每个时期之后重置i。您可以在传递给model.fit()的{​​{3}}中进行此操作。另外,请确保将shuffle=False传递给model.fit()