ValueError:GraphDef不能大于2GB

时间:2017-06-13 12:00:25

标签: python image-processing tensorflow

我正在使用tensorflow的imageNet训练模型对多个类别的图像进行分类。

我将脚本classify.py编辑为

import tensorflow as tf
import sys
import glob
import os
import pandas as pd

# Disable tensorflow compilation warnings
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf

test_path = '/Users/kaustubhmundra/Desktop/Multi-Class Classifier/test'

classes = ['room','reception','washroom','facade']

result = pd.DataFrame(columns = ['facade','washroom','room','reception'])

def predict(image_path):
    #image_path = sys.argv[1]

    # Read the image_data
    image_data = tf.gfile.FastGFile(image_path, 'rb').read()

    # Loads label file, strips off carriage return
    label_lines = [line.rstrip() for line 
                       in tf.gfile.GFile("tf_files/retrained_labels.txt")]

    # Unpersists graph from file
    with tf.gfile.FastGFile("tf_files/retrained_graph.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

    with tf.Session() as sess:
        # Feed the image_data as input to the graph and get first prediction
        softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

        predictions = sess.run(softmax_tensor, \
                 {'DecodeJpeg/contents:0': image_data})

        # print(predictions)

        pred = pd.DataFrame(predictions,columns = ['facade','washroom','room','reception'])

        # print(pred)

        global result

        result = result.append(pred)

        # print(result)

        # Sort to show labels of first prediction in order of confidence
        top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

        for node_id in top_k:
            human_string = label_lines[node_id]
            score = predictions[0][node_id]
            print('%s (score = %.5f)' % (human_string, score))



path = os.path.join(test_path, '*')
files = sorted(glob.glob(path))

i=1

for fl in files:
    print(i)
    i = i + 1
    predict(fl)

result.to_csv('predictions.csv')

虽然我使用它来预测图像,但它可以完美地工作直到24个图像,但随后显示错误:

  

文件   " /Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/ops.py" ;,   第2154行,在_as_graph_def中       提高ValueError(" GraphDef不能大于2GB。")ValueError:GraphDef不能大于2GB。

如何解决此问题?

1 个答案:

答案 0 :(得分:0)

每次调用predict()时都会导入图表,因此您需要累积一个非常大的默认graphdef。您应该更改代码,以便只在预测函数之外加载图表(文件'部分中的#Unpersists图表)。这也应该大大加快你的代码。