如何在自定义tf.keras.layers.Layer中支持遮罩

时间:2019-03-15 06:31:12

标签: python tensorflow keras

我正在实现需要支持屏蔽的自定义tf.keras.layers.Layer

考虑以下情况

embedded = tf.keras.layer.Embedding(input_dim=vocab_size + 1, 
                                    output_dim=n_dims, 
                                    mask_zero=True)
x = MyCustomKerasLayers(embedded)

现在按文档

  

mask_zero:输入值0是否为应屏蔽的特殊“填充”值。这在使用可能需要可变长度输入的循环图层时很有用。 如果为True,则模型中的所有后续层都需要支持屏蔽,否则将引发异常。因此,如果将mask_zero设置为True,则无法在词汇表中使用索引0(input_dim应等于词汇表大小+ 1)。

我想知道这是什么意思?浏览TensorFlow's custom layers guidetf.keras.layer.Layer文档时,尚不清楚应该采取什么措施来支持遮罩

  1. 我如何支持遮罩?

  2. 如何从上一层访问遮罩?

  3. 假设输入(batch, time, channels)或`(批次,时间),掩码看起来会有所不同吗?它们的形状是什么?

  4. 如何将其传递到下一层?

1 个答案:

答案 0 :(得分:0)

  1. 要支持屏蔽,应在自定义层内实现compute_mask方法

  2. 要访问掩码,只需在call方法中添加参数mask作为第二个位置参数,即可使用它(例如call(self, inputs, mask=None)

  3. 这是无法猜测的,这是层负责计算蒙版的时间

  4. 一旦实现了compute_mask,就会自动将蒙版传递到下一层

示例:

class MyCustomKerasLayers(tf.keras.layers.Layer):
    def __init__(self, .......):
        ...

    def compute_mask(self, inputs, mask=None):
        # Just pass the received mask from previous layer, to the next layer or 
        # manipulate it if this layer changes the shape of the input
        return mask

    def call(self, input, mask=None):
        # using 'mask' you can access the mask passed from the previous layer

请注意,此示例只是通过蒙版,如果图层输出的形状与接收到的形状不同,则应在compute_mask中相应地更改蒙版以传递正确的蒙版。