举个例子:我们有两个尺寸为2x2x3像素的RGB图像。四个像素中的每一个由3个间隔表示。插图可用作2D阵列。前4个中间值表示红色值,接下来的4个中间值表示绿色值,后4个中间值表示4个像素的蓝色值。
图片1:
[11, 12, 13, 14, 15, 16, 17, 18, 19, 191, 192, 193]
图片2:
[21, 22, 23, 24, 25, 26, 27, 28, 29, 291, 292, 293]
在TensorFlow中,这两个图像存储在Tensor
中img_tensor = tf.constant([[11, 12, 13, 14, 15, 16, 17, 18, 19, 191, 192, 193],
[21, 22, 23, 24, 25, 26, 27, 28, 29, 291, 292, 293]])
# Tensor("Const_10:0", shape=(2, 12), dtype=int32)
使用tf.strided_slice()
之后我想要以下格式:
[[[[11, 15, 19],
[12, 16, 191]],
[[13, 17, 192],
[14, 18, 193]]],
[[[21, 25, 29],
[22, 26, 291]],
[[23, 27, 292],
[24, 28, 293]]]]
# Goal is: Tensor("...", shape=(2, 2, 2, 3), dtype=int32)
到目前为止我尝试过:
new_img_tensor = tf.strided_slice(img_tensor, [0, 0], [3, -1], [1, 4])
但结果不完整:
[[11 15 19]
[21 25 29]]
# Tensor("StridedSlice_2:0", shape=(2, 3), dtype=int32)
有没有办法使用tf.strided_slice()
答案 0 :(得分:1)
似乎您需要reshape
+ transpose
而不是strided_slice
:
tf.InteractiveSession()
tf.transpose(tf.reshape(img_tensor, (2, 3, 2, 2)), (0, 2, 3, 1)).eval()
#array([[[[ 11, 15, 19],
# [ 12, 16, 191]],
# [[ 13, 17, 192],
# [ 14, 18, 193]]],
# [[[ 21, 25, 29],
# [ 22, 26, 291]],
# [[ 23, 27, 292],
# [ 24, 28, 293]]]])