在张量流中,我找不到用周期性边界条件进行卷积(tf.nn.conv2d)的直接可能性。
E.g。采取张量
[[1,2,3],
[4,5,6],
[7,8,9]]
和任何3x3过滤器。具有周期性边界条件的卷积原则上可以通过周期性填充到5×5
来完成[[9,7,8,9,7],
[3,1,2,3,1],
[6,4,5,6,4],
[9,7,8,9,7],
[3,1,2,3,1]]
随后以“有效”模式与过滤器进行卷积。但是,遗憾的是,函数tf.pad不支持定期填充。
有一个简单的解决方法吗?
答案 0 :(得分:4)
以下内容适合您的情况:
import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6],[7,8,9]])
b = tf.tile(a, [3, 3])
result = b[2:7, 2:7]
sess = tf.InteractiveSession()
print(result.eval())
# prints the following
array([[9, 7, 8, 9, 7],
[3, 1, 2, 3, 1],
[6, 4, 5, 6, 4],
[9, 7, 8, 9, 7],
[3, 1, 2, 3, 1]], dtype=int32)
正如评论中所指出的,这在内存方面有点低效。如果内存对你来说是一个问题,但愿意花一些计算,那么下面的内容也会有效:
pre = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
post = tf.transpose(pre)
result = tf.matmul(tf.matmul(pre, a), post)
print(result.eval())
答案 1 :(得分:1)
这是张量流中周期性填充的实现,适用于一批二维图像。它使用切片和tf.concat:
def periodic_padding(x, padding=1):
'''
x: shape (batch_size, d1, d2)
return x padded with periodic boundaries. i.e. torus or donut
'''
d1 = x.shape[1] # dimension 1: height
d2 = x.shape[2] # dimension 2: width
p = padding
# assemble padded x from slices
# tl,tc,tr
# padded_x = ml,mc,mr
# bl,bc,br
top_left = x[:, -p:, -p:] # top left
top_center = x[:, -p:, :] # top center
top_right = x[:, -p:, :p] # top right
middle_left = x[:, :, -p:] # middle left
middle_center = x # middle center
middle_right = x[:, :, :p] # middle right
bottom_left = x[:, :p, -p:] # bottom left
bottom_center = x[:, :p, :] # bottom center
bottom_right = x[:, :p, :p] # bottom right
top = tf.concat([top_left, top_center, top_right], axis=2)
middle = tf.concat([middle_left, middle_center, middle_right], axis=2)
bottom = tf.concat([bottom_left, bottom_center, bottom_right], axis=2)
padded_x = tf.concat([top, middle, bottom], axis=1)
return padded_x
import tensorflow as tf
a = tf.constant([
[[1,2,3],[4,5,6],[7,8,9]],
[[11,12,13],[14,15,16],[17,18,19]],
])
result = periodic_padding(a, padding=1)
sess = tf.InteractiveSession()
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
sess.close()
示例的输出为:
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 9 7 8 9 7]
[ 3 1 2 3 1]
[ 6 4 5 6 4]
[ 9 7 8 9 7]
[ 3 1 2 3 1]]
[[19 17 18 19 17]
[13 11 12 13 11]
[16 14 15 16 14]
[19 17 18 19 17]
[13 11 12 13 11]]]
答案 2 :(得分:0)
稍微更通用,更灵活:对一个或多个指定轴进行定期填充,并且可以为不同的轴指定不同的填充长度
import tensorflow as tf
def periodic_padding_flexible(tensor, axis,padding=1):
"""
add periodic padding to a tensor for specified axis
tensor: input tensor
axis: on or multiple axis to pad along, int or tuple
padding: number of cells to pad, int or tuple
return: padded tensor
"""
if isinstance(axis,int):
axis = (axis,)
if isinstance(padding,int):
padding = (padding,)
ndim = len(tensor.shape)
for ax,p in zip(axis,padding):
# create a slice object that selects everything from all axes,
# except only 0:p for the specified for right, and -p: for left
ind_right = [slice(-p,None) if i == ax else slice(None) for i in range(ndim)]
ind_left = [slice(0, p) if i == ax else slice(None) for i in range(ndim)]
right = tensor[ind_right]
left = tensor[ind_left]
middle = tensor
tensor = tf.concat([right,middle,left], axis=ax)
return tensor
a = tf.constant([
[[1,2,3],[4,5,6],[7,8,9]],
[[11,12,13],[14,15,16],[17,18,19]],
])
sess = tf.InteractiveSession()
result = periodic_padding_flexible(a, axis=1,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
result = periodic_padding_flexible(a, axis=2,padding=1)
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
result = periodic_padding_flexible(a, axis=(1,2),padding=(1,2))
print('a:')
print(a.eval())
print('padded a:')
print(result.eval())
输出:
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 7 8 9]
[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[ 1 2 3]]
[[17 18 19]
[11 12 13]
[14 15 16]
[17 18 19]
[11 12 13]]]
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 3 1 2 3 1]
[ 6 4 5 6 4]
[ 9 7 8 9 7]]
[[13 11 12 13 11]
[16 14 15 16 14]
[19 17 18 19 17]]]
a:
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]]
[[11 12 13]
[14 15 16]
[17 18 19]]]
padded a:
[[[ 8 9 7 8 9 7 8]
[ 2 3 1 2 3 1 2]
[ 5 6 4 5 6 4 5]
[ 8 9 7 8 9 7 8]
[ 2 3 1 2 3 1 2]]
[[18 19 17 18 19 17 18]
[12 13 11 12 13 11 12]
[15 16 14 15 16 14 15]
[18 19 17 18 19 17 18]
[12 13 11 12 13 11 12]]]