Trained TensorFlow model performs poorly at inference

Asked: 2019-09-12 14:43:38

Tags: python android tensorflow machine-learning keras

I trained an image classification model using Keras with the TensorFlow backend. The model has good accuracy on both the validation set and the test data, and I saved the whole model in .h5 format. This is my checkpoint callback:

checkpoint = ModelCheckpoint(model_name+".h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)

Since I want to use this model on Android, I converted it into a frozen binary protobuf (.pb) using keras_to_tensorflow.
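For context, the conversion that keras_to_tensorflow performs is roughly the standard TF 1.x freezing pattern. Below is a sketch of that pattern, assuming TF 1.x with standalone Keras; the function name and paths are my own, not from the converter:

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util
from keras import backend as K
from keras.models import load_model

def freeze_keras_model(h5_path, pb_path):
    """Load a saved .h5 Keras model and write a frozen .pb graph (TF 1.x)."""
    K.set_learning_phase(0)  # force dropout/batch-norm into inference mode
    model = load_model(h5_path)
    sess = K.get_session()
    # Convert all variables to constants so the weights are embedded
    # directly in the graph definition.
    output_names = [t.op.name for t in model.outputs]
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_names)
    tf.train.write_graph(frozen, '.', pb_path, as_text=False)
    # These are the tensor names you later pass to get_tensor_by_name
    # (with a ":0" suffix).
    print("Inputs:", [t.op.name for t in model.inputs])
    print("Outputs:", output_names)
```

Printing the input/output op names here is useful because a name mismatch at inference time is another common source of silently wrong results.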

When running inference with the model on a mobile device, I noticed that it gives very wrong and seemingly random predictions. I have tried to explore other reasons why this could still happen, as I found here, and it seems the problem is apparently not with loading the image.

In addition, running inference on the converted model in TensorFlow Python still gives the same wrong/random predictions. This is the code I use to perform inference in Python:

def model_predict(model_path, image_path, model_input, model_output, class_names):

    with tf.Graph().as_default() as graph: # Set default graph as graph

        with tf.Session() as sess:
            print("Load graph")

            # Load the protobuf file from disk and parse it to retrieve
            # the unserialized graph_def
            with gfile.FastGFile(model_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())

            print("Load image...")
            # Read the image, resize it to the model's input size and
            # normalize to [0, 1]. Note: np.resize would NOT do a spatial
            # resize -- it repeats/truncates the flattened pixels.
            np_image = Image.open(image_path).resize((224, 224))
            np_image = np.array(np_image).astype('float32') / 255
            np_image = np.expand_dims(np_image, axis=0)

            # Import graph_def into the current default graph (the weights
            # are embedded in the frozen graph as constants)
            tf.import_graph_def(
                graph_def,
                input_map=None,
                return_elements=None,
                name=""
            )

            # INFERENCE here
            m_input = graph.get_tensor_by_name(model_input)    # input tensor
            m_output = graph.get_tensor_by_name(model_output)  # output tensor

            print("Shape of input:", m_input.shape)
            # No variable initialization is needed: a frozen graph
            # contains only constants.

            # Run model on a single image
            session_out = sess.run(m_output, feed_dict={m_input: np_image})

            print("Predicted class:", class_names[session_out[0].argmax()])
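One preprocessing pitfall worth ruling out here: `np.resize` does not spatially resize an image; it flattens the array and repeats or truncates the pixel buffer to fill the requested shape, which scrambles the input the network sees. A quick numpy-only check (the tiny array is my own illustration):

```python
import numpy as np

# A tiny 2x2 "image" with distinct corner values.
img = np.array([[0., 100.],
                [200., 255.]], dtype=np.float32)

# np.resize flattens the array and tiles/truncates it to fill the
# requested shape -- it does NOT interpolate pixels.
bad = np.resize(img, (4, 4))
print(bad)
# Every row is just the flattened original, tiled:
# [[  0. 100. 200. 255.]
#  [  0. 100. 200. 255.]
#  [  0. 100. 200. 255.]
#  [  0. 100. 200. 255.]]

# A real image resize (e.g. PIL's Image.resize) preserves spatial
# layout and interpolates between neighboring pixels:
# from PIL import Image
# good = np.array(Image.open(path).resize((224, 224)))
```

If the Keras evaluation pipeline resized images properly (e.g. via PIL) but the inference script uses `np.resize`, the converted model would look fine in Keras yet produce random-looking predictions at inference, which matches the symptom described.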

How do I perform inference with the saved .pb model in TensorFlow Python / Android?

Others have suggested that I save the session used for training and load it in TensorFlow when performing inference. If that is the way to go, how do I load a saved session in TensorFlow Android?

I am sure the model is not overfitting; it performs very well when used with Keras.

0 Answers:

No answers yet.