AttributeError:在自定义回调中无法访问validation_data

时间:2019-09-30 09:47:09

标签: python tensorflow keras deep-learning lstm

我用LSTM实现了一个预测模型,并编写了一个自定义回调,以便访问反向缩放输入上的一些其他指标。

Metrics类如下:

class Metrics(keras.callbacks.Callback):
    def __init__(self, scaler):
        self.scaler = scaler

    def on_train_begin(self, logs):
        self._data = []

    def on_epoch_end(self, batch, logs):
        val_data, val_target = self.validation_data[0], self.validation_data[1]

        # calculating and appending the metric here
        # self._data.append({metric})

        return

    def get_data(self):
        return self._data

然后我像这样使用它:

metrics = Metrics(scaler)

model = Sequential()
model.add(LSTM(32, 
                   return_sequences=True,
                   activation='tanh', 
                   input_shape=(dataset.X_train.shape[1], dataset.X_train.shape[2])))
# more layers and model.compile here

history = model.fit(dataset.X_train, 
                    dataset.y_train,  
                    epochs=EPOCHS,  
                    validation_data=(dataset.X_valid, dataset.y_valid), 
                    callbacks=[metrics])

有什么想法吗?

0 个答案:

没有答案