TFLite-在Android上将经过训练的PB文件转换为tflite失败的预测

时间:2018-07-17 06:55:38

标签: android python tensorflow

问题描述:

训练一个pb文件,然后转换为tflite文件,我使用python测试pb文件是正确的,但是在android上测试成tflite文件后却是错误的,我认为这个问题应该与BN(tf.nn.batch_normalization ),因为当我删除BN时,android和python的结果都是相同的,但由于BN不同,因此BN是tflite支持,并且输入数据相同,所以我不知道为什么?

测试演示:

def save_to_pb(sess):
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])
    with tf.gfile.FastGFile(MODEL_DIR + 'expert-graph.pb', mode='wb') as f:
       f.write(constant_graph.SerializeToString())

def read_data(session):
   image_name = '182133.jpg'
    TEST_IMAGINE_PATH = '/home/leve/lcr/ClelebA/Celebra_crop_20w+_w128_dataset/images/test_100/'
    image = cv.imread(TEST_IMAGINE_PATH + image_name)
    image = tf.reshape(image, [128, 128, 3])
    image = image.eval(session=session)
    image = image[np.newaxis, :]
    return image

def batch_norm_lite(x, train=True, bn_decay=0.5,epsilon = 0.001,name='bn'):
    is_training = tf.convert_to_tensor(train,dtype='bool',name='is_training')
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]
    axis = list(range(len(x_shape) - 1))
    beta = tf.get_variable(name+'_beta', params_shape, initializer=tf.zeros_initializer())
    gamma = tf.get_variable(name+'_gamma', params_shape, initializer=tf.ones_initializer())
    moving_mean = tf.get_variable(name+'_moving_mean', params_shape, initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = tf.get_variable(name+'_moving_variance', params_shape, initializer=tf.ones_initializer(), trainable=False)
    mean, variance = tf.nn.moments(x, axis)
    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, bn_decay)
    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, bn_decay)
    tf.add_to_collection(name+'_update_moving_mean', update_moving_mean)
    tf.add_to_collection(name+'_update_moving_variance', update_moving_variance)
    mean, variance = control_flow_ops.cond(
        is_training, lambda: (mean, variance),
        lambda: (moving_mean, moving_variance))

   return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon,name=name)

def train_model():
    img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 128, 128, 3))
    b = tf.Variable(tf.truncated_normal((1, 128, 128, 3), seed=1),name='w1')
    y_real = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
    val = tf.add(img, b)
    val = batch_norm_lite(val)
    out = tf.identity(val, name="out")
    MSE = tf.reduce_mean(tf.square(y_real - out), name='mse')
    train_step = tf.train.GradientDescentOptimizer(0.9).minimize(MSE)
    saver = tf.train.Saver(max_to_keep=10)
    os.environ["CUDA_VISIBLE_DEVICES"] = ''
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.1
    config.gpu_options.allow_growth = True
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for epoch in range(100):
            sess.run(train_step,feed_dict={img:read_data(sess)})
            if epoch % 20 == 0:
                save_path = saver.save(sess, MODEL_DIR + MODEL_NAME, global_step=epoch + 1)
            print('step = ' + str(epoch))
        save_to_pb(sess)
train_model()

运行上层代码并获取pb文件,当您使用python测试pb文件时,您可能会得到如下结果: 39.09046 45.927864 84.797905 45.952957 62.701195 68.00796 41.789146 80.99372 81.67459 81.2203 ....

然后运行此命令以创建tflite文件(未运行bazel Frozen_graph命令,因为有太多错误#20825#20824):

bazel run --config=opt tensorflow/contrib/lite/toco:toco -- --input_file=/path/expert-graph.pb --output_file=/path/expert-graph.tflite --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --input_shape=1,128,128,3 --input_array=img --output_array=out --inference_type=FLOAT --input_data_type=FLOAT --allow_custom_ops

您将获得expert-graph.tflite,然后将tflite文件用于android,将得到如下结果: 49.392387,42.290344,135.30707,7.7764254,125.98051,79.317825,157.29184,107.58522,283.46997,36.103508 ....

Android和python具有相同的输入数据,但是结果不同,如果我删除BN并进行训练,那还可以,但是BN非常重要,我可以删除它,我不知道如何解决它,这个问题已经存在了很长时间,请帮助解决它,如果不使用TF Lite,我该怎么办?谢谢

系统信息 Linux Ubuntu 16.04: 从源码安装TensorFlow: TensorFlow版本1.8.0: Python版本:3.6: Bazel版本0.11.1: GCC /编译器版本5.4.0 20160609(Ubuntu 5.4.0-6ubuntu1〜16.04.10): CUDA / cuDNN 9.0 / 7.0.5: GPU 1060-6G:

0 个答案:

没有答案