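TensorFlow custom layer: getting the actual batch size

Time: 2021-03-04 15:07:55

Tags: python tensorflow keras

I want to implement a custom tf layer that performs a mathematical operation involving the actual batch size of the input tensor: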

import tensorflow as tf
from   tensorflow import keras

class MyLayer(keras.layers.Layer):

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        super().build(input_shape)

    def call(self,input):
        self.batch_size + 1 # do something with the batch size
        return input

However, while the graph is being built, that value is initially None, which breaks MyLayer:

input = keras.Input(shape=(10,))
x     = MyLayer()(input)
TypeError: in user code:

    <ipython-input-41-98e23e82198d>:11 call  *
        self.batch_size + 1 # do something with the batch size

    TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

Is there a way to make such a layer work once the model has been constructed?

1 answer:

Answer 0 (score: 1)

Use tf.shape inside the layer's call method to get the batch size.
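For context on why this works where input_shape[0] does not, here is a minimal sketch, assuming TensorFlow 2.x. The function name batch_info is made up for illustration. When the batch dimension is unknown at graph-construction time, the static shape x.shape[0] is None, while tf.shape(x)[0] is a tensor that is evaluated on every call.

```python
import tensorflow as tf


# Hypothetical helper: the batch dimension is declared unknown (None),
# which mirrors what keras.Input(shape=(10,)) does when building a model.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 10], dtype=tf.float32)])
def batch_info(x):
    static_bs = x.shape[0]       # Python value, None while tracing
    dynamic_bs = tf.shape(x)[0]  # scalar int32 tensor, resolved at run time
    print("static batch size while tracing:", static_bs)
    return dynamic_bs


result = batch_info(tf.zeros([5, 10]))
print(int(result))
```

The static value is what build() receives in input_shape, which is why storing input_shape[0] there yields None.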

Example

import tensorflow as tf


# custom layer
class MyLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        
    def call(self, x):
        bs = tf.shape(x)[0]
        return x, tf.add(bs, 1)
    
    
# network
x_in = tf.keras.Input(shape=(10,))  # batch dimension is implicit; matches the [5, 10] input below
x = MyLayer()(x_in)

# model def
model = tf.keras.models.Model(x_in, x)

# forward pass
_, shp = model(tf.random.normal([5, 10]))

# shape value
print(shp)
# tf.Tensor(6, shape=(), dtype=int32)
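
If the batch size has to take part in floating-point math, it usually needs a cast first, because tf.shape returns an int32 tensor. The following is a rough sketch under that assumption. The layer name ScaleByBatch and the operation itself (dividing by the batch size) are only illustrative and are not from the question.

```python
import tensorflow as tf


class ScaleByBatch(tf.keras.layers.Layer):  # hypothetical layer name
    def call(self, x):
        # Dynamic batch size, cast to the input dtype so it can be used in division.
        bs = tf.cast(tf.shape(x)[0], x.dtype)
        return x / bs


x_in = tf.keras.Input(shape=(10,))
model = tf.keras.Model(x_in, ScaleByBatch()(x_in))

# With a batch of 5 ones, every element becomes 1 / 5.
out = model(tf.ones([5, 10]))
print(out.shape)
```

Since nothing is stored in build(), the same layer instance handles batches of any size.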