使用TensorFlow自定义图像上采样

时间:2019-03-18 14:52:04

标签: python tensorflow

在TensorFlow中实现图层功能时遇到麻烦。也许有更多经验的人知道如何解决这个问题。该函数的用法应如下:

在:一个[B x W x H x 2]张量A

输出:一个名为B的新张量,大小为[B x p*W x q*W],填充如下:

for b from 0 to B: #loop over batches
    for w from 0 to W: # loop over width
        for h from 0 to H: # loop over height
            B[b,w*p:w*p+p,h*q:h*q+q] = tf.random.normal(shape=[p,q],
                                                        mean=A[b,w,h,0],
                                                        stddev=A[b,w,h,1])

我基本上想要做的是使用“随机(高斯)插值”对图像进行上采样。

我无法创建空张量并填充它,就像我通常根据伪代码所做的那样。我尝试使用的是TensorFlows tf.map_fn()函数,很不幸,该函数无法正常工作。

想法是稍后使用此层作为均值或最大池的替代方法。

也许有更简单的方法可以做到这一点?

任何帮助表示赞赏。谢谢。

1 个答案:

答案 0 :(得分:0)

您可以通过矢量化的方式(比循环或映射要快得多)来做到这一点,像这样:

import tensorflow as tf
import numpy as np

def gaussian_upsampling(A, p, q):
    s = tf.shape(A)
    B, W, H, C = s[0], s[1], s[2], s[3]
    # Add two dimensions to A for tiling
    A_exp = tf.expand_dims(tf.expand_dims(A, 2), 4)
    # Tile A along new dimensions
    A_tiled = tf.tile(A_exp, [1, 1, p, 1, q, 1])
    # Reshape
    A_tiled = tf.reshape(A_tiled, [B, W * p, H * q, C])
    # Extract mean and std
    mean_tiled = A_tiled[:, :, :, 0]
    std_tiled = A_tiled[:, :, :, 1]
    # Make base random value
    rnd = tf.random.normal(shape=[B, W * p, H * q], mean=0, stddev=1, dtype=A.dtype)
    # Scale and shift random value
    return rnd * std_tiled + mean_tiled

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    tf.random.set_random_seed(100)
    mean = tf.constant([[[ 1.0,  2.0,  3.0],
                         [ 4.0,  5.0,  6.0]],
                        [[ 7.0,  8.0,  9.0],
                         [10.0, 11.0, 12.0]]])
    std = tf.constant([[[0.1, 0.2, 0.3],
                        [0.4, 0.5, 0.6]],
                       [[0.7, 0.8, 0.9],
                        [1.0, 1.1, 1.2]]])
    A = tf.stack([mean, std], axis=-1)
    with np.printoptions(precision=2, suppress=True):
        print(sess.run(gaussian_upsampling(A, 3, 2)))

输出:

[[[ 0.94  0.97  1.82  1.67  2.89  2.96]
  [ 1.04  0.78  2.23  2.02  2.95  3.04]
  [ 0.9   0.96  1.84  1.98  2.74  3.06]
  [ 3.89  4.12  5.72  4.32  6.02  5.7 ]
  [ 3.47  4.27  4.39  4.85  6.38  5.32]
  [ 3.21  3.98  4.64  4.31  5.72  5.96]]

 [[ 8.15  7.08  7.33  7.78  8.75  9.95]
  [ 7.37  7.29  8.27  8.26  8.56  8.17]
  [ 5.91  7.95  7.9   7.81  8.43  8.64]
  [11.12 11.49 11.95 11.74 11.43 12.3 ]
  [ 9.98  9.66  9.21 10.2  12.78 12.13]
  [ 8.33 10.37 11.88 11.44 12.96 11.73]]]