I am using TensorFlow for Poets to detect features in clothing images. I trained 4 different models (sleeve, shape, length and hemline). I now pass each image URL to every model and store the results. Because the data is large (100k images), I use Spark: I broadcast the 4 models once and map over an RDD of images to detect the features. The per-image time keeps growing instead of staying constant: it starts at about 3 seconds per image and keeps increasing, and by the time the script has processed 10k images it has reached 8 seconds per image. I am new to TensorFlow and would appreciate any ideas on how to make the total execution time linear in the number of images.
# Imports assumed by the snippet below; read_tensor_from_image_file and
# load_labels come from the TensorFlow for Poets label_image.py script.
import time
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

def getLabelDresses(file_name):
    resultDict = {}
    t = read_tensor_from_image_file(file_name,
                                    input_height=input_height,
                                    input_width=input_width,
                                    input_mean=input_mean,
                                    input_std=input_std)
    input_name = "import/" + input_layer
    output_name = "import/" + output_layer

    # Hemline model
    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_hemline.value)
        tf.import_graph_def(graph_def)
        input_operation_hemline = g.get_operation_by_name(input_name)
        output_operation_hemline = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_hemline.outputs[0],
                               {input_operation_hemline.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_hemline)
            resultDict['hemline'] = labels[top_k[0]]

    # Shape model
    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_shape.value)
        tf.import_graph_def(graph_def)
        input_operation_shape = g.get_operation_by_name(input_name)
        output_operation_shape = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_shape.outputs[0],
                               {input_operation_shape.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_shape)
            resultDict['shape'] = labels[top_k[0]]

    # Length model
    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_length.value)
        tf.import_graph_def(graph_def)
        input_operation_length = g.get_operation_by_name(input_name)
        output_operation_length = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_length.outputs[0],
                               {input_operation_length.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_length)
            resultDict['length'] = labels[top_k[0]]

    # Sleeve model
    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_sleeve.value)
        tf.import_graph_def(graph_def)
        input_operation_sleeve = g.get_operation_by_name(input_name)
        output_operation_sleeve = g.get_operation_by_name(output_name)
        with tf.Session() as sess:
            results = sess.run(output_operation_sleeve.outputs[0],
                               {input_operation_sleeve.outputs[0]: t})
            results = np.squeeze(results)
            top_k = results.argsort()[-1:][::-1]
            labels = load_labels(label_file_sleeve)
            resultDict['sleeve'] = labels[top_k[0]]

    return resultDict
model_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_graph_hemline.pb"
label_file_hemline = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/hemline/retrained_labels_hemline.txt"
model_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_graph_length.pb"
label_file_length = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/length/retrained_labels_length.txt"
model_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_graph_shape.pb"
label_file_shape = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/shape/retrained_labels_shape.txt"
model_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_graph_sleeve.pb"
label_file_sleeve = "/home/ubuntu/amit/feature-sensing/tf_files/attribute/dresses/sleeve/retrained_labels_sleeve.txt"
with gfile.FastGFile(model_file_hemline, "rb") as f:
    model_data = f.read()
model_data_hemline = sc.broadcast(model_data)

with gfile.FastGFile(model_file_length, "rb") as f:
    model_data = f.read()
model_data_length = sc.broadcast(model_data)

with gfile.FastGFile(model_file_shape, "rb") as f:
    model_data = f.read()
model_data_shape = sc.broadcast(model_data)

with gfile.FastGFile(model_file_sleeve, "rb") as f:
    model_data = f.read()
model_data_sleeve = sc.broadcast(model_data)
def calculate(row):
    path = "/tmp/" + row.guid
    url = row.modelno
    print(path, url)
    if url is not None:
        import urllib.request
        urllib.request.urlretrieve(url, path)
        t1 = time.time()
        result = getLabelDresses(path)
        print(time.time() - t1)
        print(result)
        return row
    return row
product2.rdd.map(calculate).collect()
Answer 0 (score: 1)
Every call to getLabelDresses in your code adds operations to the graph (in particular, read_tensor_from_image_file builds new preprocessing ops in the default graph on each call), which is why the per-image time keeps increasing.
Split the code into a setup part (loading the models and importing the graphs), executed once, and an execution part, executed for every image. The latter should contain only calls to Session.run. A sketch of this split is given below.
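
A minimal sketch of that split, reusing the broadcasts and helpers from the question (model_data_*, label_file_*, read_tensor_from_image_file, load_labels); the load_model and classify_partition names are illustrative, not part of any API. Using mapPartitions instead of map keeps the expensive setup per partition rather than per row:

def load_model(model_data):
    # Setup: build the graph and open a session once, then keep both around.
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data)
        tf.import_graph_def(graph_def)
    sess = tf.Session(graph=graph)
    input_op = graph.get_operation_by_name("import/" + input_layer)
    output_op = graph.get_operation_by_name("import/" + output_layer)
    return sess, input_op, output_op

def classify_partition(rows):
    import urllib.request
    # Runs once per Spark partition: load all four models up front ...
    models = {
        'hemline': (load_model(model_data_hemline.value), label_file_hemline),
        'shape':   (load_model(model_data_shape.value),   label_file_shape),
        'length':  (load_model(model_data_length.value),  label_file_length),
        'sleeve':  (load_model(model_data_sleeve.value),  label_file_sleeve),
    }
    # ... then do little more than Session.run per image.
    for row in rows:
        path = "/tmp/" + row.guid
        urllib.request.urlretrieve(row.modelno, path)
        t = read_tensor_from_image_file(path,
                                        input_height=input_height,
                                        input_width=input_width,
                                        input_mean=input_mean,
                                        input_std=input_std)
        result = {}
        for name, ((sess, input_op, output_op), label_file) in models.items():
            preds = np.squeeze(sess.run(output_op.outputs[0],
                                        {input_op.outputs[0]: t}))
            result[name] = load_labels(label_file)[preds.argmax()]
        yield (row.guid, result)

product2.rdd.mapPartitions(classify_partition).collect()

Note that read_tensor_from_image_file from label_image.py also creates ops on every call, so the default graph on each executor still grows slowly; for a fully constant per-image cost its decode/resize pipeline should likewise be built once, or run inside its own throwaway graph.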
Another option is to clear the graph with tf.reset_default_graph before processing the next image, but it is less preferable.
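
If you keep the existing per-image calculate function, that less-preferable option would look roughly like this (a sketch under the same assumptions; it stops the default graph from growing, but still rebuilds the preprocessing ops and reloads all four models for every image):

def calculate(row):
    # Drop the ops accumulated in the default graph by previous images
    # before getLabelDresses builds new ones.
    tf.reset_default_graph()
    path = "/tmp/" + row.guid
    if row.modelno is not None:
        import urllib.request
        urllib.request.urlretrieve(row.modelno, path)
        print(getLabelDresses(path))
    return row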