在Keras自定义图层中获取批处理大小并使用Tensorflow操作(tf.Variable)

时间:2019-05-12 17:52:35

标签: python tensorflow keras

我想编写一个带有tensorflow操作的Keras自定义层,该操作需要将批处理大小作为输入。显然我在每个角落都在挣扎。

假设一个非常简单的层: (1)获取批量 (2)根据批量大小创建一个tf.Variable(我们称其为my_var),然后使用一些tf.random操作来更改my_var (3)最后,返回输入乘以my_var

到目前为止我尝试过的事情:

class TestLayer(Layer):

    def __init__(self, **kwargs):

        self.num_batch = None
        self.my_var = None

        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):

        self.batch_size = input_shape[0]

        var_init = tf.ones(self.batch_size, dtype = x.dtype)
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=False)

        # some tensorflow random operations to alter self.my_var

        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):

        return self.my_var * x

    def compute_output_shape(self, input_shape):

        return input_shape

现在创建一个非常简单的模型:

# define model
input_layer = Input(shape = (2, 2, 3), name = 'input_layer')
x = TestLayer()(input_layer)

# connect model
my_mod = Model(inputs = input_layer, outputs = x)
my_mod.summary()

不幸的是,无论我在代码中尝试/更改什么,都会遇到多个错误,其中大多数错误的回溯都非常隐秘(ValueError:无法将部分已知的TensorShape转换为Tensor:或ValueError:不支持任何值。)。

有什么一般建议吗?预先感谢。

1 个答案:

答案 0 :(得分:1)

如果要创建大小为batch_size的变量,则需要指定批处理大小。另外,如果要打印摘要,tf.Variable必须具有固定的形状(validatate_shape=True),并且必须可广播以成功乘以输入:

import tensorflow as tf
from tensorflow.keras.layers import Layer, Input
from tensorflow.keras.models import Model

class TestLayer(Layer):

    def __init__(self, **kwargs):
        self.num_batch = None
        self.my_var = None
        super(TestLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.batch_size = input_shape[0]
        var_init = tf.ones(self.batch_size, dtype=tf.float32)[..., None, None, None]
        self.my_var = tf.Variable(var_init, trainable=False, validate_shape=True)
        super(TestLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        res = self.my_var * x
        return res

    def compute_output_shape(self, input_shape):
        return input_shape

# define model
input_layer = Input(shape=(2, 2, 3), name='input_layer', batch_size=10)
x = TestLayer()(input_layer)

# connect model
my_mod = Model(inputs=input_layer, outputs=x)
my_mod.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_layer (InputLayer)     (10, 2, 2, 3)             0         
_________________________________________________________________
test_layer (TestLayer)       (10, 2, 2, 3)             0         
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0