在ecery批处理keras中获取培训数据

时间:2019-11-14 12:39:21

标签: keras deep-learning loss-function

我想知道是否有可能获得每个批次喀拉拉邦使用的训练数据集。

获取y_true和y_pred很容易,但是我想知道用于预测该批次的训练数据集。

def my_loss(y_true, y_pred):
    loss=K.mean(K.abs(y_true-y_pred))
    return loss

model.compile(loss=my_loss, optimizer='rmsprop', metrics=['mae'])

这是确定

但是我想要这样的东西:

def my_loss(y_true, y_pred, x_train):

my_loss() missing 1 required positional argument: 'x_train'

感谢您的帮助

1 个答案:

答案 0 :(得分:1)

如果您要传递>>> extender(100)y_true以外的参数,则可以这样定义您的自定义损失:

y_pred

在编译时,您可以传递与def custom_loss(x_train): def my_loss(y_true, y_pred): loss=K.mean(K.abs(y_true-y_pred)) # do something with x_train return loss return my_loss 相同形状的张量。

x_train

这是定义自定义损失的方法。此外,您还希望获得当前的input_tensor = Input(shape=input_shape) #specify your input shape, same as x_train. model.compile(loss=custom_loss(input_tensor), optimizer='rmsprop', metrics=['mae']) 批处理,现在,批处理是您必须自己处理的事情。
最后,在训练时,您可以使用model.train_on_batch