在keras回调中使用带有自定义参数的自定义函数

时间:2019-07-09 13:10:55

标签: keras callback metaclass custom-function

我正在训练喀拉拉邦一个模型,我想在每个时期后绘制结果图。我知道keras回调提供了“ on_epoch_end”函数,如果一个人想在每个纪元后进行一些计算,可以重载,但是我的函数采用了一些附加参数,这些参数给出时会因元类错误而使代码崩溃。详细信息如下:

这是我现在的操作方式,运行正常:-

class NewCallback(Callback):

def on_epoch_end(self, epoch, logs={}):  #working fine, printing epoch after each epoch
    print("EPOCH IS: "+str(epoch))


epochs=5
batch_size = 16
model_saved=False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
           callbacks=[NewCallback()])

但是我想要这样的回调函数:-

class NewCallback(Callback,models,data,batch_size):
   def on_epoch_end(self, epoch, logs={}):
     print("EPOCH IS: "+str(epoch))
     x=models.predict(data)
     plt.plot(x)
     plt.savefig(epoch+".png")

如果我这样称呼它:

callbacks=[NewCallback(models, data, batch_size=batch_size)]

我收到此错误:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases 

我正在寻找一个更简单的解决方案来调用我的函数或解决此元类错误,任何帮助将不胜感激!

2 个答案:

答案 0 :(得分:2)

我认为您想做的是定义一个从回调派生的类,并将模型,数据等作为构造函数参数。所以:

class NewCallback(Callback):
    """ NewCallback descends from Callback
    """
    def __init__(self, models, data, batch_size):
        """ Save params in constructor
        """
        self.models = models

    def on_epoch_end(self, epoch, logs={}):
        x = self.models.predict(self.data)

答案 1 :(得分:0)

感谢您的上述回答。我会详细说明一下。请在您方便时更新答案并发表评论。

下面的 model.fit 函数将调用此自定义回调函数 CustomCallback 以使用张量板绘制混淆矩阵

# Call back function to save the confusion matrix using tensorboard.
class CustomCallback(tf.keras.callbacks.Callback):

    # Save all of your required parameter values in a constructor
    def __init__(self, model, test_images, test_labels, class_names):
        self.model = model
        self.test_images = test_images
        self.test_labels = test_labels
        self.class_names = class_names

    def on_epoch_end(self, epoch, logs):
        # Use the model to predict the values from the validation dataset.
        test_pred_raw = self.model.predict(self.test_images)
        test_pred = np.argmax(test_pred_raw, axis=1)
    
        # Calculate the confusion matrix.
        cm = sklearn.metrics.confusion_matrix(self.test_labels, test_pred)
        # Log the confusion matrix as an image summary.
        figure = plot_confusion_matrix(cm, class_names=self.class_names)
        cm_image = plot_to_image(figure)

        # Log the confusion matrix as an image summary.
        with file_writer_cm.as_default():
           tf.summary.image("Confusion Matrix", cm_image, step=epoch)

  # fit function  
  model.fit(
        train_images,
        train_labels,
        epochs=5,
        callbacks=CustomCallback(model, 
        test_images,
        test_labels,
        class_names 
        )])