Tensor Flow使用旧图而不是新图

时间:2017-07-14 17:47:51

标签: python tensorflow

我使用retrain.py重新训练了两种不同的分类模型。

为了预测两张图片的标签,我从Label_image.py创建了getLabel方法,如下所示:

def getLabel(localFile, graphKey, labelKey):

    image_data_str = tf.gfile.FastGFile(localFile, 'rb').read()

    # Loads label file, strips off carriage return
    label_lines = [line.rstrip() for line
                   in tf.gfile.GFile(labelKey)]

    # Unpersists graph from file
    with tf.gfile.FastGFile(graphKey, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

    sess = tf.Session()
    with sess:
        # Feed the image_data as input to the graph and get first prediction
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

        predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data_str})

        # Sort to show labels of first prediction in order of confidence
        top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
        series = []
        count = 1
        for node_id in top_k:
            human_string = label_lines[node_id]
            if count==1:
                label = human_string
                count+=1
            score = predictions[0][node_id]
            print('%s (score = %.5f)' % (human_string, score))
            series.append({"name": human_string, "data": [score * 100]})
        sess.close()

    return label, series

我称之为

 label,series = predict.getLabel(localFile, 'graph1.pb', 'labels1.txt')
 label,series = predict.getLabel(localFile, 'graph2.pb', 'labels2.txt')

但对于第二个函数调用,它使用的是旧图形,即graph1.pb&它给出了以下误差,因为模型1比模型2有更多的类别。

human_string = label_lines[node_id]
IndexError: list index out of range

我无法理解为什么会这样。有人可以告诉如何加载第二个图吗?

1 个答案:

答案 0 :(得分:0)

看起来正在发生的事情是,您正在为predict1.getFinalLabel的两次调用呼叫同一会话。您应该做的是定义两个单独的会话,并分别初始化每个会话(例如,有predict2.getFinalLabel(change))。如果您发布更多代码,我可以提供更多详细信息和代码。