在Tensorflow中使用3d转置卷积进行上采样

时间:2018-01-18 18:25:12

标签: python tensorflow computer-vision deep-learning convolution

我在Tensorflow中定义了一个3D转置卷积,如下所示:

def weights(shape):
    return tf.Variable(tf.truncated_normal(shape, mean = 0.0, stddev=0.1))

def biases(shape):
    return tf.Variable(tf.constant(value = 0.1, shape = shape))

def trans_conv3d(x, W, output_shape, strides, padding):
    return tf.nn.conv3d_transpose(x, W, output_shape, strides, padding)

def transconv3d_layer(x, shape, out_shape, strides, padding):
   # shape: [depth, height, width, output_channels, in_channels].
   # output_shape: [batch, depth, height, width, output_channels]
    W = weights(shape)
    b = biases([shape[4]]) 
    return tf.nn.elu(trans_conv3d(x, W, out_shape, strides, padding) + b)

假设我的前一个图层{4}具有x[2, 1, 1, 1, 10]batch = 2depth = 1的形状为height = 1的4D张量width = 1,和in_channels = 10列出here

如何使用transconv3d_layerx在一系列图层上进行上采样,以获得最终形状,例如[2, 100, 100, 100, 10]或其他内容类似的?我不清楚如何通过转置层跟随张量的形状。

1 个答案:

答案 0 :(得分:1)

以下是如何使用它:

input = tf.random_normal(shape=[2, 1, 1, 1, 10])
deconv1 = transconv3d_layer(input,
                            shape=[2, 3, 3, 10, 10],
                            out_shape=[2, 50, 50, 50, 10],
                            strides=[1, 1, 1, 1, 1],
                            padding='SAME')
deconv2 = transconv3d_layer(deconv1,
                            shape=[2, 3, 3, 10, 10],
                            out_shape=[2, 100, 100, 100, 10],
                            strides=[1, 1, 1, 1, 1],
                            padding='SAME')
# deconv3 ...

print(deconv1)  # Tensor("Elu:0", shape=(2, 50, 50, 50, 10), dtype=float32)
print(deconv2)  # Tensor("Elu_1:0", shape=(2, 100, 100, 100, 10), dtype=float32)

基本上,您应将每个out_shape指定为您想要上传到(2, 50, 50, 50, 10)(2, 100, 100, 100, 10),...

input shape: [batch, depth, height, width, in_channels] filter shape: [depth, height, width, output_channels, in_channels] output shape: [batch, depth, height, width, output_channels]

为了清楚起见,以下是不同张量中尺寸的含义:

template <class T>
struct has_unique_keys : std::false_type {};

template <class... P>
struct has_unique_keys<std::set<P...>> : std::true_type {};

template <class... P>
struct has_unique_keys<std::map<P...>> : std::true_type {};

// ...