张量流周期性填充

时间:2016-08-22 20:47:55

标签: python tensorflow

在张量流中,我找不到用周期性边界条件进行卷积(tf.nn.conv2d)的直接可能性。

E.g。采取张量

[[1,2,3],
 [4,5,6],
 [7,8,9]]

和任何3x3过滤器。具有周期性边界条件的卷积原则上可以通过周期性填充到5×5

来完成
[[9,7,8,9,7],
 [3,1,2,3,1],
 [6,4,5,6,4],
 [9,7,8,9,7],
 [3,1,2,3,1]]

随后以“有效”模式与过滤器进行卷积。但是,遗憾的是,函数tf.pad不支持定期填充。

有一个简单的解决方法吗?

3 个答案:

答案 0 :(得分:4)

以下内容适合您的情况:

import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6],[7,8,9]])
b = tf.tile(a, [3, 3])
result = b[2:7, 2:7]
sess = tf.InteractiveSession()
print(result.eval())

# prints the following 
array([[9, 7, 8, 9, 7],
       [3, 1, 2, 3, 1],
       [6, 4, 5, 6, 4],
       [9, 7, 8, 9, 7],
       [3, 1, 2, 3, 1]], dtype=int32)

正如评论中所指出的,这在内存方面有点低效。如果内存对你来说是一个问题,但愿意花一些计算,那么下面的内容也会有效:

pre = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
post = tf.transpose(pre)
result = tf.matmul(tf.matmul(pre, a), post)
print(result.eval())

答案 1 :(得分:1)

这是张量流中周期性填充的实现,适用于一批二维图像。它使用切片和tf.concat:

def periodic_padding(x, padding=1):
    '''
    x: shape (batch_size, d1, d2)
    return x padded with periodic boundaries. i.e. torus or donut
    '''
    d1 = x.shape[1] # dimension 1: height
    d2 = x.shape[2] # dimension 2: width
    p = padding
    # assemble padded x from slices
    #            tl,tc,tr
    # padded_x = ml,mc,mr
    #            bl,bc,br
    top_left = x[:, -p:, -p:] # top left
    top_center = x[:, -p:, :] # top center
    top_right = x[:, -p:, :p] # top right
    middle_left = x[:, :, -p:] # middle left
    middle_center = x # middle center
    middle_right = x[:, :, :p] # middle right
    bottom_left = x[:, :p, -p:] # bottom left
    bottom_center = x[:, :p, :] # bottom center
    bottom_right = x[:, :p, :p] # bottom right
    top = tf.concat([top_left, top_center, top_right], axis=2)
    middle = tf.concat([middle_left, middle_center, middle_right], axis=2)
    bottom = tf.concat([bottom_left, bottom_center, bottom_right], axis=2)
    padded_x = tf.concat([top, middle, bottom], axis=1)
    return padded_x

import tensorflow as tf
a = tf.constant([
    [[1,2,3],[4,5,6],[7,8,9]],
    [[11,12,13],[14,15,16],[17,18,19]],
])
result = periodic_padding(a, padding=1)
sess = tf.InteractiveSession()
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
sess.close()

示例的输出为:

a:
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]

 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 9  7  8  9  7]
  [ 3  1  2  3  1]
  [ 6  4  5  6  4]
  [ 9  7  8  9  7]
  [ 3  1  2  3  1]]

 [[19 17 18 19 17]
  [13 11 12 13 11]
  [16 14 15 16 14]
  [19 17 18 19 17]
  [13 11 12 13 11]]]

答案 2 :(得分:0)

稍微更通用,更灵活:对一个或多个指定轴进行定期填充,并且可以为不同的轴指定不同的填充长度

import tensorflow as tf

def periodic_padding_flexible(tensor, axis,padding=1):
    """
        add periodic padding to a tensor for specified axis
        tensor: input tensor
        axis: on or multiple axis to pad along, int or tuple
        padding: number of cells to pad, int or tuple

        return: padded tensor
    """


    if isinstance(axis,int):
        axis = (axis,)
    if isinstance(padding,int):
        padding = (padding,)

    ndim = len(tensor.shape)
    for ax,p in zip(axis,padding):
        # create a slice object that selects everything from all axes,
        # except only 0:p for the specified for right, and -p: for left

        ind_right = [slice(-p,None) if i == ax else slice(None) for i in range(ndim)]
        ind_left = [slice(0, p) if i == ax else slice(None) for i in range(ndim)]
        right = tensor[ind_right]
        left = tensor[ind_left]
        middle = tensor
        tensor = tf.concat([right,middle,left], axis=ax)

    return tensor



a = tf.constant([
    [[1,2,3],[4,5,6],[7,8,9]],
    [[11,12,13],[14,15,16],[17,18,19]],
])

sess = tf.InteractiveSession()

result = periodic_padding_flexible(a, axis=1,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())

result = periodic_padding_flexible(a, axis=2,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())

result = periodic_padding_flexible(a, axis=(1,2),padding=(1,2))
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())

输出:

a:
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 7  8  9]
  [ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]
  [ 1  2  3]]
 [[17 18 19]
  [11 12 13]
  [14 15 16]
  [17 18 19]
  [11 12 13]]]
a:
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 3  1  2  3  1]
  [ 6  4  5  6  4]
  [ 9  7  8  9  7]]
 [[13 11 12 13 11]
  [16 14 15 16 14]
  [19 17 18 19 17]]]
a:
[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]
 [[11 12 13]
  [14 15 16]
  [17 18 19]]]
padded a:
[[[ 8  9  7  8  9  7  8]
  [ 2  3  1  2  3  1  2]
  [ 5  6  4  5  6  4  5]
  [ 8  9  7  8  9  7  8]
  [ 2  3  1  2  3  1  2]]
 [[18 19 17 18 19 17 18]
  [12 13 11 12 13 11 12]
  [15 16 14 15 16 14 15]
  [18 19 17 18 19 17 18]
  [12 13 11 12 13 11 12]]]