使用NCHW格式在tensorflow.nn.conv2d中过滤形状

时间:2017-03-14 10:20:31

标签: python tensorflow

关注Tensorflow's best practices for performance,我使用的是NCHW数据格式,但我不确定要在tensorflow.nn.conv2d中使用的滤镜形状。

该文档说使用[filter_height, filter_width, in_channels, out_channels]作为NHWC格式,但不清楚如何处理NCHW。

应该使用相同的形状吗?

1 个答案:

答案 0 :(得分:0)

使用相同的滤镜形状应该有效。对函数参数的唯一更改是步幅。举个例子,让我们说你希望你的架构可以使用这两种格式,这也是推荐的:

# input -> Tensor in NCHW format
if use_nchw:
    result = tf.nn.conv2d(
        input=input,
        filter=filter,
        strides=[1, 1, stride, stride],
        data_format='NCHW')
else:
    input_t = tf.transpose(input, [0, 2, 3, 1]) # NCHW to NHWC

    result = tf.nn.conv2d(
        input=input_t,
        filter=filter,
        strides=[1, stride, stride, 1])

    result = tf.transpose(result, [0, 3, 1, 2]) # NHWC to NCHW