TensorFlow模型对大型数据集

时间:2018-03-19 05:47:30

标签: python tensorflow machine-learning pyspark deep-learning

我正在使用TensorFlow for Poets来检测服装图像中的特征。我训练了4种不同的模特(袖子,形状,长度和底边)。现在我将图像URL传递给每个模型并存储结果。由于我拥有巨大的数据(100k图像),所以使用spark一次广播4个模型并通过图像RDD来检测功能。这是指数时间。它从3秒开始/图像&不断增加执行时间。当脚本已经感知到10k图像时,其执行时间达到8秒/图像。我是Tensorflow的新手,如果有任何想法让执行时间成为线性,我将非常感激。

def getLabelDresses(file_name):
    resultDict = {}
    t = read_tensor_from_image_file(file_name,
                              input_height=input_height,
                              input_width=input_width,
                              input_mean=input_mean,
                              input_std=input_std)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer



    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_hemline.value)

        tf.import_graph_def(graph_def)

        input_operation_hemline = g.get_operation_by_name(input_name);
        output_operation_hemline = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_hemline.outputs[0],{input_operation_hemline.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_hemline)
        resultDict['hemline'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_shape.value)

        tf.import_graph_def(graph_def)

        input_operation_shape = g.get_operation_by_name(input_name);
        output_operation_shape = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_shape.outputs[0],{input_operation_shape.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_shape)
        resultDict['shape'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_length.value)

        tf.import_graph_def(graph_def)

        input_operation_length = g.get_operation_by_name(input_name);
        output_operation_length = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_length.outputs[0],{input_operation_length.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_length)
        resultDict['length'] = labels[top_k[0]]

    with tf.Graph().as_default() as g:

        graph_def = tf.GraphDef()

        graph_def.ParseFromString(model_data_sleeve.value)

        tf.import_graph_def(graph_def)

        input_operation_sleeve = g.get_operation_by_name(input_name);
        output_operation_sleeve = g.get_operation_by_name(output_name);

        with tf.Session() as sess:
            results = sess.run(output_operation_sleeve.outputs[0],{input_operation_sleeve.outputs[0]: t})

        results = np.squeeze(results)

        top_k = results.argsort()[-1:][::-1]
        labels = load_labels(label_file_sleeve)
        resultDict['sleeve'] = labels[top_k[0]]     

    return resultDict;


model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"

with gfile.FastGFile(model_file_hemline, "rb") as f:
    model_data = f.read()
    model_data_hemline = sc.broadcast(model_data)
with gfile.FastGFile(model_file_length, "rb") as f:
    model_data = f.read()
    model_data_length = sc.broadcast(model_data)
with gfile.FastGFile(model_file_shape, "rb") as f:
    model_data = f.read()
    model_data_shape = sc.broadcast(model_data)
with gfile.FastGFile(model_file_sleeve, "rb") as f:
    model_data = f.read()
    model_data_sleeve = sc.broadcast(model_data)

def calculate(row):
    path = "/tmp/"+row.guid
    url = row.modelno
    print(path, url)
    if(url is not None):
        import urllib.request
        urllib.request.urlretrieve(url, path)
        t1=time.time() 
        result = getLabelDresses(path)
        print(time.time()-t1)
        print(result)
        return row
    return row

product2.rdd.map(calculate).collect()

1 个答案:

答案 0 :(得分:1)

代码中对getLabelDresses的每次调用都会向图表添加操作。

将代码拆分为设置(模型加载)部分,执行一次,并为每个图像执行执行部分。后者应仅包含对Session.run的调用。

另一种选择是在使用tf.reset_default_graph处理下一张图像之前清除图表。但它不太可取。