我想在 tensorflow 中制作自定义模型。我创建了创建基本层的函数,例如 Conv2D、Dense、Flatten。 我坚持使用批量标准化实现。
我想在一个列表 trainable_variables
中包含所有 self.parameters
(参数)。因为我的“自定义”层建立在 tf.Module 上,所以我假设所有可训练的变量在 self.trainable_parameters
中都是可用的。目前 self.trainable_variables
不不包含 MyBatchNormalization
可训练变量。
以下 Colab 示例:
Colab Example - “创建模型”部分打印层名称和可训练参数。
我希望有工作的 BatchNormalization 层,可以在 train_on_batch
方法中更新(训练)可训练变量。
答案 0 :(得分:0)
我找到了解决方案 - 我刚刚添加了 __build(self, input_shape)
函数,它调用 super(MyBatchNormalization, self).build(input_shape)
。
实现如下:
class MyBatchNormalization(tf.keras.layers.BatchNormalization):
def __init__(self, input_shape, name=None):
super().__init__(name=name)
self.out_shape = input_shape
self.__build(input_shape)
def __build(self, input_shape):
super(MyBatchNormalization, self).build(input_shape)
我还没有完全测试过。但是我可以访问 MyBatchNormalization 的 trainable_variables
并且看起来很有希望。