How to convert the CIFAR10 tutorial to NCHW

Time: 2018-01-08 15:47:09

Tags: tensorflow

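I am trying to convert the TensorFlow CIFAR10 tutorial from NHWC to NCHW, but I can't figure out how to do so. I have only found answers such as this one, which amount to a few lines of code with no explanation of how they work or where to use them. Here are a couple of failed attempts using this approach: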

def inference(images):
    with tf.variable_scope('conv1') as scope:
        kernel = _variable_with_weight_decay('weights',
                                             shape=[5, 5, 3, 64],
                                             stddev=5e-2,
                                             wd=0.0)

        # ****************************************************************** #

        ### Original
        conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')

        ### Attempt 1
        imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
        conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME')
        conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

        ### Attempt 2
        kern = tf.transpose(kernel, [0, 3, 1, 2]) # NHWC -> NCHW
        conv = tf.nn.conv2d(images, kern, [1, 1, 1, 1], padding='SAME')
        conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

        # ****************************************************************** #

        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)
        _activation_summary(conv1)

    ...

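These produce the following errors, respectively:

ValueError: Dimensions must be equal, but are 24 and 3 for 'conv1/Conv2D' (op: 'Conv2D') with input shapes: [64,3,24,24], [5,5,3,64].

ValueError: Dimensions must be equal, but are 3 and 5 for 'conv1/Conv2D' (op: 'Conv2D') with input shapes: [64,24,24,3], [5,64,5,3].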

Can someone provide a set of steps I can follow to successfully convert this example to NCHW?

1 answer:

Answer 0 (score: 2)

In your attempt #1, try the following:

conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')

(i.e., add data_format = 'NCHW' to the arguments)
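
Why this helps: without data_format, tf.nn.conv2d assumes NHWC and reads the channel count from the input's last dimension, which must match the kernel's third dimension. The kernel itself stays laid out as [height, width, in_channels, out_channels] whichever data_format is used, so it should not be transposed. The sketch below is a pure-Python imitation of that check, not TensorFlow's actual code; conv2d_output_shape is a hypothetical helper, and it assumes stride 1 with SAME padding.

```python
def conv2d_output_shape(input_shape, kernel_shape, data_format="NHWC"):
    """Mimic conv2d's channel check for stride 1 and SAME padding.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError("unsupported data_format: %r" % (data_format,))
    # The position of the channel axis depends on the input layout.
    channel_axis = data_format.index("C")
    in_channels = input_shape[channel_axis]
    # The kernel is always [height, width, in_channels, out_channels].
    kernel_in, kernel_out = kernel_shape[2], kernel_shape[3]
    if in_channels != kernel_in:
        raise ValueError("Dimensions must be equal, but are %d and %d"
                         % (in_channels, kernel_in))
    # SAME padding with stride 1 keeps the spatial size; only channels change.
    output_shape = list(input_shape)
    output_shape[channel_axis] = kernel_out
    return output_shape


# Attempt 1 with data_format='NCHW': channels are read from axis 1.
print(conv2d_output_shape([64, 3, 24, 24], [5, 5, 3, 64], "NCHW"))

# Attempt 1 as originally written: axis 3 (24) is compared with 3 and fails.
try:
    conv2d_output_shape([64, 3, 24, 24], [5, 5, 3, 64])
except ValueError as err:
    print(err)
```

The same rule accounts for the second error in the question: the transposed kernel has shape [5, 64, 5, 3], so its third dimension (5) no longer matches the 3 input channels.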

For example, as a complete script:

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as session:

    kernel = tf.ones(shape=[5, 5, 3, 64])
    images = tf.ones(shape=[64,24,24,3])

    imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
    conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
    conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC

    print("conv=",conv.eval())
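
The permutation lists passed to tf.transpose above follow directly from the layout names: entry i of the permutation is the position, in the source layout, of the axis that should end up at position i. A small pure-Python sketch derives them; the helper names layout_perm and permute are made up for illustration and are not TensorFlow functions.

```python
def layout_perm(src, dst):
    """Return the transpose permutation that converts layout `src` to `dst`."""
    if sorted(src) != sorted(dst):
        raise ValueError("layouts must use the same axis letters")
    # Output axis i is taken from the source axis with the same letter.
    return [src.index(axis) for axis in dst]


def permute(shape, perm):
    """Apply a transpose permutation to a shape list."""
    return [shape[i] for i in perm]


to_nchw = layout_perm("NHWC", "NCHW")
to_nhwc = layout_perm("NCHW", "NHWC")
print(to_nchw)  # [0, 3, 1, 2]
print(to_nhwc)  # [0, 2, 3, 1]
print(permute([64, 24, 24, 3], to_nchw))  # [64, 3, 24, 24]
```

If the network is kept in NCHW throughout rather than transposed back after each convolution, later ops such as tf.nn.bias_add and tf.nn.max_pool also accept a data_format argument, and the ksize and strides lists for pooling would need to be written in NCHW order as well.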