tf.assign不会使用model._function_kwargs更新自定义回调init构造函数中的tf.Variable值

时间:2019-06-26 12:28:15

标签: python tensorflow keras callback

我需要创建一个自定义回调以获取目标值,即y_true和y_pred(预测值)。所以,我读了这篇文章:Create keras callback to save model predictions and targets for each batch during training

并创建了我的回调函数,就像在答案中创建的一样

from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

        # the shape of these 2 variables will change according to batch shape
        # to handle the "last batch", specify `validate_shape=False`
        self.var_y_true = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)

    def on_batch_end(self, batch, logs=None):
        # evaluate the variables and save them into lists
        self.targets.append(K.eval(self.var_y_true))
        self.outputs.append(K.eval(self.var_y_pred))

# build a simple model
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam')

# initialize the variables and the `tf.assign` ops
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
           tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches}  # use `model._function_kwargs` if using `Model` instead of `Sequential`

当我添加on_epoch_end并尝试打印self.targets的值时。 我得到0的数组。 对于on_epoch_end,代码如下:

def on_epoch_end(self, epoch, logs={}):
    print(self.targets)

我的模型是使用功能性API Model创建的,并且已加载了预先训练的权重,而不是顺序权重。在将模型编译为model.compile之后,我实例化了callback并创建了fetches对象,并将其添加到train_function中,如下所示:

cbk = CollectOutputAndTarget()

fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
                   tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches}

然后,我使用数据生成器调用model.fit_generator。我正在0s中得到self.targets,如果var_y_truevar_y_pred正在被model.targetsmodel.outputs更新,则不应该发生。另外,我不明白,如果我们已经为cbk.var_y_truecbk.var_y_pred分配了值,那么为什么我们需要使用model._function_kwargs

在设置model.train_function = None之后和调用fetches之前,我尝试使用fit_generator,但是仍然得到相同的结果。

0 个答案:

没有答案