我想将3D卷积分为(2+1)D,这意味着先进行空间卷积,然后再进行时间卷积,即 kernel=(D,H,W) -> kernel=(1,H,W) 和 kernel=(D,1,1)。从理论上讲,它具有较少的参数,并且应该节省内存。例如,具有3通道输入和64输出通道的内核(7x7x7),#params = 7x7x7x3x64 = 65,856;可分离的3D:1x7x7x3x64 + 7x1x1x64x64 = 38,080。
但是,在我的测试中,sep3D 比 unit3D 占用了更多内存。为什么会这样?
def unit3D(inputs, output_channels,
           kernel_shape=(1, 1, 1),
           strides=(1, 1, 1),
           activation_fn=tf.nn.relu,
           use_batch_norm=True,
           use_bias=False,
           padding='same',
           is_training=True,
           name=None):
    """Single 3D-conv unit: Conv3D, then optional BatchNorm, then optional activation.

    Args:
        inputs: 5-D input tensor (NDHWC layout assumed by `tf.layers.conv3d` defaults
            — TODO confirm against callers).
        output_channels: number of filters produced by the convolution.
        kernel_shape: (depth, height, width) of the 3D kernel.
        strides: (depth, height, width) strides of the convolution.
        activation_fn: non-linearity applied last; pass None to skip.
        use_batch_norm: when True, apply `tf.contrib.layers.batch_norm` after the conv.
        use_bias: whether the conv layer adds a bias term (off by default since
            batch norm makes it redundant).
        padding: padding mode forwarded to the conv layer.
        is_training: forwarded to batch norm (controls moving-average updates).
        name: variable-scope name; defaults to 'unit3D'.

    Returns:
        The transformed 5-D tensor.
    """
    with tf.variable_scope(name, 'unit3D', [inputs]):
        out = tf.layers.conv3d(inputs,
                               filters=output_channels,
                               kernel_size=kernel_shape,
                               strides=strides,
                               padding=padding,
                               use_bias=use_bias)
        if use_batch_norm:
            out = tf.contrib.layers.batch_norm(out, is_training=is_training)
        if activation_fn is not None:
            out = activation_fn(out)
        return out
def sep3D(inputs, output_channels,
          kernel_shape=(1, 1, 1),
          strides=(1, 1, 1),
          activation_fn=tf.nn.relu,
          use_batch_norm=True,
          use_bias=False,
          padding='same',
          is_training=True,
          name=None):
    """(2+1)D separable conv unit with BatchNorm + non-linearity.

    A (k_t, k, k) kernel is factored into a spatial (1, k, k) conv followed by
    a temporal (k_t, 1, 1) conv, each with its own BatchNorm/activation.

    Args:
        inputs: 5-D input tensor (NDHWC layout assumed by `tf.layers.conv3d`
            defaults — TODO confirm against callers).
        output_channels: number of filters for BOTH the spatial and temporal conv.
        kernel_shape: (k_t, k_h, k_w) of the original full 3D kernel being factored.
        strides: either a single int applied to all three dims, or a
            (s_t, s_h, s_w) triple; temporal stride goes to the second conv.
        activation_fn: non-linearity applied after each conv; pass None to skip.
        use_batch_norm: apply `tf.contrib.layers.batch_norm` after each conv.
        use_bias: whether the conv layers add a bias term.
        padding: padding mode forwarded to both conv layers.
        is_training: forwarded to batch norm.
        name: variable-scope name; defaults to 'sep3D'.

    Returns:
        The transformed 5-D tensor.
    """
    k_t, k_h, k_w = kernel_shape
    # Accept a scalar stride as shorthand for the same stride on every axis.
    if isinstance(strides, int):
        s_t = s_h = s_w = strides
    else:
        s_t, s_h, s_w = strides
    with tf.variable_scope(name, 'sep3D', [inputs]):
        # Spatial pass: (1, k_h, k_w) kernel, no temporal stride yet.
        # NOTE(review): the intermediate activation here already has
        # `output_channels` channels at full temporal resolution; this extra
        # stored tensor (plus its BN statistics) is a likely cause of the
        # higher memory use vs. plain unit3D even though there are fewer
        # parameters — confirm with a memory profiler.
        net = tf.layers.conv3d(inputs, filters=output_channels,
                               kernel_size=(1, k_h, k_w),
                               strides=(1, s_h, s_w),
                               padding=padding,
                               use_bias=use_bias)
        if use_batch_norm:
            net = tf.contrib.layers.batch_norm(net, is_training=is_training)
        if activation_fn is not None:
            net = activation_fn(net)
        # Temporal pass: (k_t, 1, 1) kernel carries the temporal stride.
        net = tf.layers.conv3d(net, filters=output_channels,
                               kernel_size=(k_t, 1, 1),
                               strides=(s_t, 1, 1),
                               padding=padding,
                               use_bias=use_bias)
        if use_batch_norm:
            net = tf.contrib.layers.batch_norm(net, is_training=is_training)
        # BUG FIX: the original assigned `net` only inside this branch but
        # unconditionally executed `return net`, so calling with
        # activation_fn=None raised NameError. Rebinding one name throughout
        # fixes that and also lets TF free the spatial intermediate sooner.
        if activation_fn is not None:
            net = activation_fn(net)
        return net