在keras中,您可以使用参数(y_true, y_pred)
定义自定义损失。
我怎么知道它们与哪些模式相关?
我的意思是,y_true
是具有batchSize
元素的张量。如何将这些元素与原始X
相关联?
我想知道y_true[0]
和相对的X[i]
之间的对应关系。
答案 0 :(得分:1)
所以您想要的是这样的损失函数
def custom_loss(y_true, y_pred, X):
因为您需要输入来进行损失计算。 据我所知,在Keras中这是不可能直接实现的。
一种可能的解决方法是拥有正在运行的索引:
X = ...
Y = ...
batch_size = ...
i = 0
def custom_loss(y_true, y_pred):
x = X[i*batch_size:(i+1)*batch_size]
loss = ...
i += 1
return loss
请确保在每个时期之后重置i
。您可以在传递给model.fit()
的{{3}}中进行此操作。另外,请确保将shuffle=False
传递给model.fit()
。