我需要创建一个自定义回调以获取目标值,即y_true和y_pred(预测值)。所以,我读了这篇文章:Create keras callback to save model predictions and targets for each batch during training
并创建了我的回调函数,就像在答案中创建的一样
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
class CollectOutputAndTarget(Callback):
def __init__(self):
super(CollectOutputAndTarget, self).__init__()
self.targets = [] # collect y_true batches
self.outputs = [] # collect y_pred batches
# the shape of these 2 variables will change according to batch shape
# to handle the "last batch", specify `validate_shape=False`
self.var_y_true = tf.Variable(0., validate_shape=False)
self.var_y_pred = tf.Variable(0., validate_shape=False)
def on_batch_end(self, batch, logs=None):
# evaluate the variables and save them into lists
self.targets.append(K.eval(self.var_y_true))
self.outputs.append(K.eval(self.var_y_pred))
# build a simple model
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam')
# initialize the variables and the `tf.assign` ops
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches} # use `model._function_kwargs` if using `Model` instead of `Sequential`
当我添加on_epoch_end
并尝试打印self.targets
的值时。
我得到0的数组。
对于on_epoch_end
,代码如下:
def on_epoch_end(self, epoch, logs={}):
print(self.targets)
我的模型是使用功能性API Model
创建的,并且已加载了预先训练的权重,而不是顺序权重。在将模型编译为model.compile
之后,我实例化了callback
并创建了fetches
对象,并将其添加到train_function
中,如下所示:
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches}
然后,我使用数据生成器调用model.fit_generator
。我正在0s
中得到self.targets
,如果var_y_true
和var_y_pred
正在被model.targets
和model.outputs
更新,则不应该发生。另外,我不明白,如果我们已经为cbk.var_y_true
和cbk.var_y_pred
分配了值,那么为什么我们需要使用model._function_kwargs
?
在设置model.train_function = None
之后和调用fetches
之前,我尝试使用fit_generator
,但是仍然得到相同的结果。