如何在TensorFlow中使用tf.strided_slice()更改形状?

时间:2018-01-07 19:06:46

标签: python tensorflow

举个例子:我们有两个尺寸为2x2x3像素的RGB图像。四个像素中的每一个由3个间隔表示。插图可用作2D阵列。前4个中间值表示红色值,接下来的4个中间值表示绿色值,后4个中间值表示4个像素的蓝色值。

图片1:

[11, 12, 13, 14, 15, 16, 17, 18, 19, 191, 192, 193]

图片2:

[21, 22, 23, 24, 25, 26, 27, 28, 29, 291, 292, 293]

在TensorFlow中,这两个图像存储在Tensor

img_tensor = tf.constant([[11, 12, 13, 14, 15, 16, 17, 18, 19, 191, 192, 193],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 291, 292, 293]])
# Tensor("Const_10:0", shape=(2, 12), dtype=int32)

使用tf.strided_slice()之后我想要以下格式:

[[[[11, 15, 19],
[12, 16, 191]],
[[13, 17, 192],
[14, 18, 193]]],
[[[21, 25, 29],
[22, 26, 291]],
[[23, 27, 292],
[24, 28, 293]]]]
# Goal is: Tensor("...", shape=(2, 2, 2, 3), dtype=int32)

到目前为止我尝试过:

new_img_tensor = tf.strided_slice(img_tensor, [0, 0], [3, -1], [1, 4])

但结果不完整:

[[11 15 19]
 [21 25 29]]
# Tensor("StridedSlice_2:0", shape=(2, 3), dtype=int32)

有没有办法使用tf.strided_slice()

将尺寸从2D更改为4-D

1 个答案:

答案 0 :(得分:1)

似乎您需要reshape + transpose而不是strided_slice

tf.InteractiveSession()
tf.transpose(tf.reshape(img_tensor, (2, 3, 2, 2)), (0, 2, 3, 1)).eval()

#array([[[[ 11,  15,  19],
#         [ 12,  16, 191]],

#        [[ 13,  17, 192],
#         [ 14,  18, 193]]],


#       [[[ 21,  25,  29],
#         [ 22,  26, 291]],

#        [[ 23,  27, 292],
#         [ 24,  28, 293]]]])