自定义TensorFlow层未注册的正则化损失

时间:2019-01-30 11:32:38

标签: python tensorflow

我已经创建了一个自定义层,即经过大量简化,相关代码为:

class MyConv2D(tf.keras.layers.Layer):
  def __init__self, kernel_regularizer="glorot_uniform", ...):
    super(MyConv2D, self).__init__(...)
    self.kernel_regularizer=regularizers.get(kernel_regularizer)
    ...

  def build(self, input_shape):
    self.kernel = self.add_weight(
      name="kernel",
      ...
      regularizer=self.kernel_regularizer)
    self.built = True

tf.keras模型中使用该图层时,该图层将按预期工作。 Hovewer,我还将公开一个tf.layers类型的my_conv2d版本,该版本大致实现为:

def my_conv2d(inputs, kernel_regularizer="glorot_uniform"):
  conv2d = MyConv2D(kernel_regularizer=kernel_regularizer)
  return conv2d.apply(inputs)

问题在于,如果我现在打电话给tf.losses.get_regularization_losses(),我会得到空列表。如果我将my_conv2d的实现更改为以下内容:

def my_conv2d(inputs, kernel_regularizer="glorot_uniform"):
  conv2d = MyConv2D(kernel_regularizer=kernel_regularizer)
  print(conv2d.losses)
  return conv2d.apply(inputs)

我看到它正确地打印了正则化损失。但是,这仅在单GPU训练中有效。使用MirroredStrategy_tensor_conversion_mirrored行的assert not as_ref内得到了一个异常。我认为这是因为我正在尝试在创建两个塔的上下文处理程序中打印张量。

我在做什么错?如何以my_conv2d层的用户身份访问正则化损失,以及如何在my_conv2d层中注册损失?与tf.layers.conv2d的实现相比,我看不到我缺少任何代码。

0 个答案:

没有答案