How do I handle batch normalization update_ops in the Keras functional model API?

Time: 2019-08-21 11:26:13

Tags: python tensorflow keras model batch-normalization

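I am experimenting with the functional API for Keras models and am trying to set up two dataset streams that use the same model with shared weights (including batch normalization). When I create the model on top of the first data stream and then call the model on the second data stream, I can see additional update ops in TensorBoard, but the model does not include these newly created ops. My question is: what would a good way of coding this look like, while still being able to use traditional TensorFlow sessions, dataset iterators, and custom losses and optimizers?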

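I have attached some code showing that the behavior I am asking for does in fact work for the BatchNormalization layer itself. Even in this case, though, it would be nice to have more control over the update ops: for example, instead of iterating through the list and checking names, having the update ops directly tied to, or returned by, the call of the layer itself, so that one can associate the update ops directly with the correct data stream.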

import tensorflow as tf
import shutil
import os.path as osp


shutil.rmtree(osp.join('/tmp', 'keras_model_testtb'), ignore_errors=True)
tb_saver = tf.summary.FileWriter(osp.join(
    '/tmp', 'keras_model_testtb',
))

input1 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
batchnorm = tf.keras.layers.BatchNormalization()
output1 = batchnorm(input1, training=tf.constant(True))
# following print shows:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
print(batchnorm.updates)
input2 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
output2 = batchnorm(input2, training=tf.constant(True))
# following print shows:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1_1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1_1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
# update ops of both layer calls are merged into one list and one has to check the name of the ops
# to use them correctly with optimizers
print(batchnorm.updates)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    tb_saver.add_graph(session.graph)
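
For illustration, the name-checking workaround mentioned above might look roughly like the following plain-Python sketch. It works only on the op name strings printed in the comments above and assumes that each call of the layer gets its own outermost name scope. The helper `group_update_ops_by_scope` is a hypothetical name, not part of TensorFlow or Keras.

```python
from collections import defaultdict


def group_update_ops_by_scope(op_names):
    """Group op names by their outermost name scope.

    Hypothetical helper: it assumes the first path component of each
    name identifies the layer call (and so the data stream) that
    created the op.
    """
    groups = defaultdict(list)
    for name in op_names:
        scope = name.split('/', 1)[0]  # text before the first '/'
        groups[scope].append(name)
    return dict(groups)


# Op names copied from the printed output of batchnorm.updates above.
op_names = [
    'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp',
    'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp',
    'batch_normalization_v1_1/AssignMovingAvg/AssignSubVariableOp',
    'batch_normalization_v1_1/AssignMovingAvg_1/AssignSubVariableOp',
]

groups = group_update_ops_by_scope(op_names)
for scope, names in groups.items():
    print(scope, len(names))
```

With the real graph, the names would come from `op.name` for each entry of `batchnorm.updates`. Matching on names is brittle, which is why a more direct association between a layer call and its update ops would be preferable.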

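Below is some example code using the Keras Model API. As you can see, the model's update ops only gain the two additional update ops, as in the previous example, when a Keras Input layer is used. With a dataset iterator, the behavior appears to be different.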

#!/usr/bin/env python3

import tensorflow as tf
import shutil
import os.path as osp


def reshape(sample):
    return tf.cast(tf.tile(sample[None, None, None], [2, 4, 3]), tf.float32)


ds = tf.data.Dataset.range(10)
ds = ds.map(reshape)
input_ds = ds.make_one_shot_iterator().get_next()


shutil.rmtree(osp.join('/tmp', 'keras_model_testtb'), ignore_errors=True)
tb_saver = tf.summary.FileWriter(osp.join(
    '/tmp', 'keras_model_testtb',
))

input1 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
output1 = tf.keras.layers.BatchNormalization()(input1, training=tf.constant(True))
model = tf.keras.models.Model(inputs=input1, outputs=output1)
# following prints show:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
# []
print(model.updates)
print([n.name for n in tf.get_default_graph().as_graph_def().node if 'model' in n.name and 'AssignSubVariableOp' in n.name])
output_ds = model(input_ds)
# following prints show:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
# ['model/batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp', 'model/batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp']
print(model.updates)
print([n.name for n in tf.get_default_graph().as_graph_def().node if 'model' in n.name and 'AssignSubVariableOp' in n.name])
# Above I expected to have found the new update ops also in model.updates, as is the case when you call the model on keras Input tensors:
input2 = tf.keras.layers.Input(shape=(None, 3), dtype=tf.float32)
output2 = model(input2)
# following prints show:
# [
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'model_1/batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp' type=AssignSubVariableOp>,
#   <tf.Operation 'model_1/batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp' type=AssignSubVariableOp>
# ]
# [
#   'model/batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp', 'model/batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp',
#   'model_1/batch_normalization_v1/AssignMovingAvg/AssignSubVariableOp', 'model_1/batch_normalization_v1/AssignMovingAvg_1/AssignSubVariableOp'
# ]
print(model.updates)
print([n.name for n in tf.get_default_graph().as_graph_def().node if 'model' in n.name and 'AssignSubVariableOp' in n.name])

with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    tb_saver.add_graph(session.graph)
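
For context on why the missing update ops matter, here is a plain-Python sketch of what each `AssignMovingAvg` op does to a shared moving statistic. It assumes the usual Keras update rule, `moving -= (moving - batch_stat) * (1 - momentum)`, with the default momentum of 0.99. The function name and the batch values are made up for illustration.

```python
def moving_average_update(moving, batch_stat, momentum=0.99):
    """One moving-average step, written as a subtract-and-assign.

    Illustrative helper (not a TensorFlow function); it mirrors an
    assign_sub of (moving - batch_stat) * (1 - momentum).
    """
    return moving - (moving - batch_stat) * (1.0 - momentum)


# One shared moving mean, updated once per data stream.
# The batch means below are made-up example values.
moving_mean = 0.0
for batch_mean in [1.0, 3.0]:  # stream 1, then stream 2
    moving_mean = moving_average_update(moving_mean, batch_mean)

print(moving_mean)
```

If the update ops of the second stream are never run, only the first stream's batch statistics contribute to the shared moving averages.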

0 answers:

No answers