如何在张量流的回调中访问训练和测试数据?

时间:2019-12-03 22:32:14

标签: python tensorflow callback

>     import tensorflow as tf
>     
>     class MyMetric(tf.keras.callbacks.Callback):
>        def on_epoch_end(self,epoch,logs={}):
>            # how to access X_train and X_val here
> 
>     ...
>     model.fit(X_train,y_train,batch_size=32,epochs=10,validation_data=(X_val,y_val),shuffle=True,callbacks=[MyMetric()]

我正在尝试使用回调在Tensorflow 2.0中实现自定义指标。在on_epoch_end方法中,我需要访问fit方法提供的训练和验证数据(整个样本,而不是批次)。有什么办法吗?谢谢!

2 个答案:

答案 0 :(得分:3)

接受训练和测试数据集作为自定义回调类的初始化参数,然后在 on_epoch_end 方法中使用它。

像这样

class MyMetric(keras.callbacks.Callback):

  def __init__(self, X_test):
    self.X_test = X_test

在调用 fit 时,将测试集作为参数传递给您的自定义回调,如下所示

model.fit(X_train,y_train,batch_size=32,epochs=10,validation_data=(X_val,y_val),shuffle=True,callbacks=[MyMetric(X_test)]

关于https://keras.io/guides/writing_your_own_callbacks/的更多详情

答案 1 :(得分:0)

您可以编辑.fit函数并传递额外的列表或队列,然后将额外的参数传递给回调函数...可能是一个队列,然后由另一个线程或函数处理该队列。

我对Paramiko库做了类似的修改,并且效果很好well