我已将模型从PyTorch转换为Keras,并使用后端提取张量流图。由于PyTorch的数据格式是NCHW,因此提取和保存的模型也是如此。在将模型转换为TFLite时,由于格式为NCHW,因此无法转换。有没有办法将整个图表转换为NHCW?
答案 0 :(得分:1)
最好让图表的数据格式与TFLite匹配,以加快推理速度。一种方法是手动将转置操作插入图形,如以下示例所示: How to convert the CIFAR10 tutorial to NCHW
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as session:
kernel = tf.ones(shape=[5, 5, 3, 64])
images = tf.ones(shape=[64,24,24,3])
imgs = tf.transpose(images, [0, 3, 1, 2]) # NHWC -> NCHW
conv = tf.nn.conv2d(imgs, kernel, [1, 1, 1, 1], padding='SAME', data_format = 'NCHW')
conv = tf.transpose(conv, [0, 2, 3, 1]) # NCHW -> NHWC
print("conv=",conv.eval())
答案 1 :(得分:0)
不幸的是,当前无法将NCHW图转换为NHWC。您必须从NHWC图开始进行训练,如果以后要使用TF lite进行训练。