有什么方法可以将.pb文件的数据格式从NCHW转换为NHWC?

时间:2020-02-03 23:40:34

标签: python-3.x tensorflow pytorch conv-neural-network pre-trained-model

我有一个CNN模型,该模型在Pytorch中根据数据格式 N(批)x C(通道)x H(高度)x W(宽度)进行了训练。我将预先训练的模型保存为 model.pth 。之后,我使用现有功能

model.pth-> model.onnx 转换了预训练模型。
torch.onnx.export(model, dummy_input, "model.onnx")

然后,我通过以下模块转换了 model.onnx-> model.pb

import onnx
from onnx_tf.backend import prepare 

model_onnx = onnx.load('model.onnx')
tf_rep = prepare(model_onnx)
tf_rep.export_graph('model.pb')

问题是:我想在需要 NHWC 数据格式的CPU设备上使用此model.pb。但是,我的模型基于NCHW数据格式。是否有任何方法可以将Model.pb的数据格式从NCHW转换为NHWC?

2 个答案:

答案 0 :(得分:1)

简短的回答,您处于困境中。

长答案,这是困难的,但仍然可能。导致您遇到问题的原因是您的图形已经过训练。创建训练图时,效率低下,但更容易转换NCHW-> NHWC。请参见类似的答案herehere

现在,您必须使用自定义卷积运算符来重载conv2D运算符。这是一个入门的伪代码。

   tensor Conv2D(X, W, B) {
     int perm[] = {0, 3, 1, 2};
     X = transposeTensor(X, perm);
     W = transposeTensor(W, perm);
     Y = Conv2D_orig(X, W, B, ...) ;
     perm = {0, 2, 3, 1};
     return transposeTensor(Y, perm);
   }

答案 1 :(得分:0)

您可以转置输入张量吗?例如input.transpose(1,2).transpose(2,3)吗?

>>> torch.randn( (3,3,3,3), names=['n','c','h','w']).transpose(1,2).transpose(2,3).names
('n', 'h', 'w', 'c')