具有训练权重的keras子层功能

时间:2019-05-14 07:40:37

标签: python tensorflow keras keras-layer

我有一个带有功能的自定义层。功能具有w训练权重。如何在下面的代码中使用它。

class ConvCaps(tf.keras.layers.Layer):
    """This constructs a convolution capsule layer from a primary or convolution capsule layer.
        i: input capsules (32)
        o: output capsules (32)
        batch size: 24
        spatial dimension: 14x14
        kernel: 3x3
    :param inputs: a primary or convolution capsule layer with poses and activations
            pose: (24, 14, 14, 32, 4 * 4)
            activation: (24, 14, 14, 32)
    :param shape: the shape of convolution operation kernel, [kh, kw, i, o] = (3, 3, 32, 32)
    :param strides: often [1, 2, 2, 1] (stride 2), or [1, 1, 1, 1] (stride 1).
    :param iterations: number of iterations in EM routing. 3
    :param name: name.

    :return: (poses, activations).

    """

    def __init__(self, kernel_size=[5, 5], out_capsules=32, pose_shape=[3, 3], strides=2, batch_size=24, iterations=3, padding="same", name=''):
        super(ConvCaps, self).__init__(name=name)
        assert len(pose_shape) == 2, "The input Tensor should have shape=[W H]"

        self.out_capsules = out_capsules
        self.pose_shape = pose_shape
        self.strides = strides
        self.batch_size = batch_size
        self.iterations = iterations

    def build(self, input_shape):
        # Be sure to call this at the end
        # beta_v and beta_a one for each output capsule: (1, 1, 1, 32)
        self.beta_v = self.add_weight(
            name='beta_v', shape=[1, 1, 1, self.out_capsules], dtype=tf.float32,
            initializer='glorot_normal'
        )
        self.beta_a = self.add_weight(
            name='beta_a', shape=[1, 1, 1, self.out_capsules], dtype=tf.float32,
            initializer='glorot_normal'
        )
        super(ConvCaps, self).build(input_shape)

    def call(self, input_tensor):

        inputs_poses, inputs_activations = input_tensor
        pose_size = inputs_poses.get_shape()[-1]  # 4
        # Tile the input capusles' pose matrices to the spatial dimension of the output capsules
        # Such that we can later multiple with the transformation matrices to generate the votes.
        # (?, 14, 14, 32, 4 * 4) -> (?, 6, 6, 3x3=9, 32x16=512)
        inputs_poses = kernel_tile(inputs_poses, 3, self.strides)
        # Tile the activations needed for the EM routing
        # (?, 14, 14, 32) -> (?, 6, 6, 9, 32)
        inputs_activations = kernel_tile(inputs_activations, 3, self.strides)

        input_shape = int(inputs_activations.get_shape()[-1])  # by default 32

        spatial_size = int(inputs_activations.get_shape()[1])  # by default 6
        # Reshape it for later operations
        inputs_poses = tf.reshape(
            inputs_poses, shape=[-1, 3 * 3 * input_shape, 16])  # (?, 9x32=288, 16)
        inputs_activations = tf.reshape(
            inputs_activations, shape=[-1, spatial_size, spatial_size, 3 * 3 * input_shape])  # (?, 6, 6, 9x32=288)

        # Generate the votes by multiply it with the transformation matrices
        # (864, 288, 32, 16)
        votes = mat_transform(inputs_poses, self.out_capsules,
                            size=self.batch_size*spatial_size*spatial_size)
        # Reshape the vote for EM routing
        votes_shape = votes.get_shape()
        votes = tf.reshape(votes, shape=[self.batch_size, spatial_size, spatial_size,
                                        votes_shape[-3], votes_shape[-2], votes_shape[-1]])  # (24, 6, 6, 288, 32, 16)


        # Use EM routing to compute the pose and activation
        # votes (24, 6, 6, 3x3x32=288, 32, 16), inputs_activations (?, 6, 6, 288)
        # poses (24, 6, 6, 32, 16), activation (24, 6, 6, 32)
        poses, activations = em_routing(votes, inputs_activations, self.beta_v, self.beta_a, self.iterations, name='em_routing')

        # Reshape it back to 4x4 pose matrix
        poses_shape = poses.get_shape()
        # (24, 6, 6, 32, 4 * 4)
        poses = tf.reshape(
            poses, [
                poses_shape[0], poses_shape[1], poses_shape[2], poses_shape[3], pose_size
            ]
        )

        return poses, activations

mat_transform 函数具有w,您可以在下面看到函数代码

def mat_transform(input, output_cap_size, size):
    """Compute the vote.
    ```mat_transform``` extracts the transformation matrices parameters as a TensorFlow trainable variable \f$w\f$.
    It then multiplies with the “tiled” input pose matrices to generate the votes for the parent capsules.

    :param inputs: shape (size, 288, 16)
    :param output_cap_size: 32

    :return votes: (24, 5, 5, 3x3=9, 136)
    """

    caps_num_i = int(input.get_shape()[1])  # 288
    # (size, 288, 1, 4, 4)
    output = tf.reshape(input, shape=[size, caps_num_i, 1, 4, 4])

    w = tf.contrib.slim.variable('w', shape=[1, caps_num_i, output_cap_size, 4, 4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))  # (1, 288, 32, 4, 4)
    w = K.tile(w, [size, 1, 1, 1, 1])  # (24, 288, 32, 4, 4)

    # (size, 288, 32, 4, 4)
    output = K.tile(output, [1, 1, output_cap_size, 1, 1])

    votes = tf.matmul(output, w)  # (24, 288, 32, 4, 4)
    # (size, 288, 32, 16)
    votes = tf.reshape(votes, [size, caps_num_i, output_cap_size, 16])

    return votes

此行或keras层需要进行哪些更改?

tf.contrib.slim.variable('w', shape=[1, caps_num_i, output_cap_size, 4, 4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))

我如何在喀拉拉邦进行这项培训?

有关此代码的更多信息,您可以在https://jhui.github.io/2017/11/14/Matrix-Capsules-with-EM-routing-Capsule-Network/

上查看原始的tensorflow代码。

0 个答案:

没有答案