First, I wanted to use tf.reshape() to reshape a 2-D tensor into a 4-D one. I assumed tf.reshape() would turn [batch, array] into [batch, height, width, channels] (NHWC) order, but in practice it produced [batch, channels, height, width] (NCHW) order.
Example:
import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()
a = np.array([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16],[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]])
print(a.shape)
# [batch_size, channels, height, width]
b = sess.run(tf.reshape(a, shape=[2, 3, 4, 4]))
# [batch_size, height, width, channels]
c = sess.run(tf.reshape(a, shape=[2, 4, 4, 3]))
print(b)
print('*******')
print(c)
The result is:
(2, 48)
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]]
*******
[[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 1]
[ 2 3 4]]
[[ 5 6 7]
[ 8 9 10]
[11 12 13]
[14 15 16]]]
[[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
[[13 14 15]
[16 1 2]
[ 3 4 5]
[ 6 7 8]]
[[ 9 10 11]
[12 13 14]
[15 16 1]
[ 2 3 4]]
[[ 5 6 7]
[ 8 9 10]
[11 12 13]
[14 15 16]]]]
So I set data_format='channels_first' on the conv and pooling layers to work with the reshaped tensors in NCHW order, and training does indeed work fine.
It even gives better results, as @mrry mentioned here, which I think is understandable since NCHW is cuDNN's default order.
However, I cannot add images to the summary with tf.summary.image(), which is documented here, because the tensor it expects must be in [batch_size, height, width, channels] order.
Moreover, if I train with input images in [batch, height, width, channels] order and visualize them, the displayed images are wrong.
It is also worth mentioning that the training results are then not as good as with [batch, channels, height, width] order.
I have a couple of questions:
1. Why does tf.reshape() convert [batch, array] into (NCHW) order rather than (NHWC) order? I tested with TF on both CPU and GPU and got the same result. I also tried np.reshape(), with the same result. (That is why I suspect I am misunderstanding something here.)
2. How can I visualize images in TensorBoard with tf.summary.image() when they are in (NCHW) order? (Question #2 was solved using the suggestion from @Maosi Chen. Thanks.)
I trained the model on a GPU (TF version 1.4), and the images come from the CIFAR-10 dataset.
Thanks.
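The behavior behind question 1 can be reproduced in plain NumPy: reshape never permutes data, it only re-groups the existing row-major buffer into a new shape, so a flat row that was stored channel after channel naturally comes out looking like NCHW. A minimal sketch (the one-image array below mirrors the layout of the example above):

```python
import numpy as np

# One image stored as a flat row: 3 channels of a 4x4 image,
# laid out channel after channel, as in the example above.
flat = np.arange(1, 17)              # one 4x4 channel: values 1..16
a = np.tile(flat, 3)[np.newaxis, :]  # shape (1, 48): [batch, channels*height*width]

# reshape only re-groups the row-major buffer -- it never moves any element.
nchw = a.reshape(1, 3, 4, 4)         # channels come first because they were first in memory
nhwc = nchw.transpose(0, 2, 3, 1)    # an explicit permutation is needed for NHWC
```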
Answer 0 (score: 2)
You can reorder the dimensions with tf.transpose (https://www.tensorflow.org/api_docs/python/tf/transpose).
Note that the perm elements are dimension indices of the source tensor.
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
a = np.array([[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16],[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]])
print(a.shape)
# [batch_size, channels, height, width]
b = sess.run(tf.reshape(a, shape=[2, 3, 4, 4]))
# [batch_size, height, width, channels]
c = sess.run(tf.transpose(b, perm=[0, 2, 3, 1]))
print(b)
print('*******')
print(c)
Result:
(2, 48)
[[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]
[[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]
[[ 1 2 3 4]
[ 5 6 7 8]
[ 9 10 11 12]
[13 14 15 16]]]]
*******
[[[[ 1 1 1]
[ 2 2 2]
[ 3 3 3]
[ 4 4 4]]
[[ 5 5 5]
[ 6 6 6]
[ 7 7 7]
[ 8 8 8]]
[[ 9 9 9]
[10 10 10]
[11 11 11]
[12 12 12]]
[[13 13 13]
[14 14 14]
[15 15 15]
[16 16 16]]]
[[[ 1 1 1]
[ 2 2 2]
[ 3 3 3]
[ 4 4 4]]
[[ 5 5 5]
[ 6 6 6]
[ 7 7 7]
[ 8 8 8]]
[[ 9 9 9]
[10 10 10]
[11 11 11]
[12 12 12]]
[[13 13 13]
[14 14 14]
[15 15 15]
[16 16 16]]]]
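As a sanity check, the same permutation can be reproduced in plain NumPy without a session; np.transpose with the same perm gives the identical result. A minimal sketch, rebuilding b from the example data:

```python
import numpy as np

# rebuild b from the answer: two batch rows of 48 values, grouped as NCHW
b = np.tile(np.arange(1, 17), (2, 3)).reshape(2, 3, 4, 4)

# perm=[0, 2, 3, 1]: keep batch, move channels last -> NHWC
c = np.transpose(b, (0, 2, 3, 1))
```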