在tf.nn.conv2d中stride列表的确切含义是什么?

时间:2018-01-31 07:25:31

标签: tensorflow

conv2d( input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None)
strides=[b,h,w,c]

我知道b表示batchh表示heightw表示widthc表示{{1} }}。我看到channelb始终是1.如果cb = 2是什么意思?

1 个答案:

答案 0 :(得分:2)

Stride是您要在特定方向跳过的金额。您的每个批次都是4维(batch_size, height, width, channels)。但是,你知道计算不应该跳过任何batch,也不应该跳过任何channel,但是GPU看到的只是一个4D张量,因此请求stride每个维度。

tf.nn.conv2d是Tensorflow中的低级实现,它实际暴露了GPU API。还有另一个高级实现,tf.layers.Conv2d只允许您使用height stridewidth stride传递两个元素元组。但是,如果您想使用低级API(可能是由于对参数的更多控制),您应始终将批次和列的步幅保持为1。