我有一个带有功能的自定义层。功能具有w
训练权重。如何在下面的代码中使用它。
class ConvCaps(tf.keras.layers.Layer):
"""This constructs a convolution capsule layer from a primary or convolution capsule layer.
i: input capsules (32)
o: output capsules (32)
batch size: 24
spatial dimension: 14x14
kernel: 3x3
:param inputs: a primary or convolution capsule layer with poses and activations
pose: (24, 14, 14, 32, 4 * 4)
activation: (24, 14, 14, 32)
:param shape: the shape of convolution operation kernel, [kh, kw, i, o] = (3, 3, 32, 32)
:param strides: often [1, 2, 2, 1] (stride 2), or [1, 1, 1, 1] (stride 1).
:param iterations: number of iterations in EM routing. 3
:param name: name.
:return: (poses, activations).
"""
def __init__(self, kernel_size=[5, 5], out_capsules=32, pose_shape=[3, 3], strides=2, batch_size=24, iterations=3, padding="same", name=''):
super(ConvCaps, self).__init__(name=name)
assert len(pose_shape) == 2, "The input Tensor should have shape=[W H]"
self.out_capsules = out_capsules
self.pose_shape = pose_shape
self.strides = strides
self.batch_size = batch_size
self.iterations = iterations
def build(self, input_shape):
# Be sure to call this at the end
# beta_v and beta_a one for each output capsule: (1, 1, 1, 32)
self.beta_v = self.add_weight(
name='beta_v', shape=[1, 1, 1, self.out_capsules], dtype=tf.float32,
initializer='glorot_normal'
)
self.beta_a = self.add_weight(
name='beta_a', shape=[1, 1, 1, self.out_capsules], dtype=tf.float32,
initializer='glorot_normal'
)
super(ConvCaps, self).build(input_shape)
def call(self, input_tensor):
inputs_poses, inputs_activations = input_tensor
pose_size = inputs_poses.get_shape()[-1] # 4
# Tile the input capusles' pose matrices to the spatial dimension of the output capsules
# Such that we can later multiple with the transformation matrices to generate the votes.
# (?, 14, 14, 32, 4 * 4) -> (?, 6, 6, 3x3=9, 32x16=512)
inputs_poses = kernel_tile(inputs_poses, 3, self.strides)
# Tile the activations needed for the EM routing
# (?, 14, 14, 32) -> (?, 6, 6, 9, 32)
inputs_activations = kernel_tile(inputs_activations, 3, self.strides)
input_shape = int(inputs_activations.get_shape()[-1]) # by default 32
spatial_size = int(inputs_activations.get_shape()[1]) # by default 6
# Reshape it for later operations
inputs_poses = tf.reshape(
inputs_poses, shape=[-1, 3 * 3 * input_shape, 16]) # (?, 9x32=288, 16)
inputs_activations = tf.reshape(
inputs_activations, shape=[-1, spatial_size, spatial_size, 3 * 3 * input_shape]) # (?, 6, 6, 9x32=288)
# Generate the votes by multiply it with the transformation matrices
# (864, 288, 32, 16)
votes = mat_transform(inputs_poses, self.out_capsules,
size=self.batch_size*spatial_size*spatial_size)
# Reshape the vote for EM routing
votes_shape = votes.get_shape()
votes = tf.reshape(votes, shape=[self.batch_size, spatial_size, spatial_size,
votes_shape[-3], votes_shape[-2], votes_shape[-1]]) # (24, 6, 6, 288, 32, 16)
# Use EM routing to compute the pose and activation
# votes (24, 6, 6, 3x3x32=288, 32, 16), inputs_activations (?, 6, 6, 288)
# poses (24, 6, 6, 32, 16), activation (24, 6, 6, 32)
poses, activations = em_routing(votes, inputs_activations, self.beta_v, self.beta_a, self.iterations, name='em_routing')
# Reshape it back to 4x4 pose matrix
poses_shape = poses.get_shape()
# (24, 6, 6, 32, 4 * 4)
poses = tf.reshape(
poses, [
poses_shape[0], poses_shape[1], poses_shape[2], poses_shape[3], pose_size
]
)
return poses, activations
mat_transform 函数具有w
,您可以在下面看到函数代码
def mat_transform(input, output_cap_size, size):
"""Compute the vote.
```mat_transform``` extracts the transformation matrices parameters as a TensorFlow trainable variable \f$w\f$.
It then multiplies with the “tiled” input pose matrices to generate the votes for the parent capsules.
:param inputs: shape (size, 288, 16)
:param output_cap_size: 32
:return votes: (24, 5, 5, 3x3=9, 136)
"""
caps_num_i = int(input.get_shape()[1]) # 288
# (size, 288, 1, 4, 4)
output = tf.reshape(input, shape=[size, caps_num_i, 1, 4, 4])
w = tf.contrib.slim.variable('w', shape=[1, caps_num_i, output_cap_size, 4, 4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0)) # (1, 288, 32, 4, 4)
w = K.tile(w, [size, 1, 1, 1, 1]) # (24, 288, 32, 4, 4)
# (size, 288, 32, 4, 4)
output = K.tile(output, [1, 1, output_cap_size, 1, 1])
votes = tf.matmul(output, w) # (24, 288, 32, 4, 4)
# (size, 288, 32, 16)
votes = tf.reshape(votes, [size, caps_num_i, output_cap_size, 16])
return votes
此行或keras层需要进行哪些更改?
tf.contrib.slim.variable('w', shape=[1, caps_num_i, output_cap_size, 4, 4], dtype=tf.float32, initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))
我如何在喀拉拉邦进行这项培训?
有关此代码的更多信息,您可以在https://jhui.github.io/2017/11/14/Matrix-Capsules-with-EM-routing-Capsule-Network/
上查看原始的tensorflow代码。