如何为定制运营商自动实现NCHW格式?

时间:2018-08-28 17:03:52

标签: tensorflow

我构建了一个启用GPU的Tensorflow。当我将NHWC图像(或特征图)提供给tf.nn.conv2d时,我发现在相应的C ++内核Conv2DOp :: Compute()中,它自动将NHWC转换为NCHW。我也 定义了一个新的自定义运算符,该运算符接受要素图作为输入:

REGISTER_OP("RegionOp")
.Input("value: T")
.Output("output: T")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(......);

但是,我在自定义的运算符内核RegionOp :: Compute()中发现,要素映射格式始终为NHWC。

是否可以像tf.nn.conv2d(Conv2DOp)一样在新的自定义运算符中将NHWC自动转换为NCHW?

非常感谢您!

0 个答案:

没有答案