我有一个图像分类器,试图将图像分为三类。我使用开源TensorFlow为诗人训练我的模型:https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#4
我有一个大约350,000个图像的列表,我想要分类,我正在使用带有线程的多进程在本地下载一批图像,对它们进行分类,然后清除图像目录。
多处理/线程:
def download_image(photo_url):
ext = photo_url.split('.')[-1]
if ext in extensions:
image_data = requests.get(photo_url, timeout=5).content
photo_name = '{}.{}'.format(os.urandom(12).encode('hex'), ext)
photo_dest = os.path.join(tmp_images_dir, photo_name)
with open(photo_dest, 'wb') as f:
f.write(image_data)
def thread_handler(image_list):
jobs = []
for image in image_list.split(','):
thread = threading.Thread(target=download_image, args=(image,))
jobs.append(thread)
thread.start()
for j in jobs:
j.join()
def multi_download_images(image_list):
pool = Pool(processes=cpu_count() - 1)
print 'downloading images...'
pool.map(thread_handler,image_list)
pool.close()
pool.join()
图片分类:
def classify_image_batch():
print 'labeling images...'
image_list = [f for f in os.listdir(tmp_images_dir)]
with tf.gfile.FastGFile(graph_dir, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
results = []
with tf.Session() as sess:
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
bar = ProgressBar()
for image in bar(image_list):
try:
image_path = os.path.join(tmp_images_dir, image)
image_data = tf.gfile.FastGFile(image_path, 'rb').read()
predictions = sess.run(softmax_tensor,
{'DecodeJpeg/contents:0': image_data})
top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
category = label_lines[top_k[0]]
results.append((image, category))
except:
results.append((image, 'failed'))
update_mongo_with_classification(results)
sess.close()
循环
def loop_image_classification():
if tf.gfile.Exists(tmp_images_dir):
tf.gfile.DeleteRecursively(tmp_images_dir)
tf.gfile.MakeDirs(tmp_images_dir)
multi_download_images()
classify_image_batch()
gc.collect()
我循环下载,分类,用最后的gc.collect()语句清除以清理任何内容。
这在前几次迭代中运行顺利,但后来我注意到我的机器停止使用它的所有内核并依赖单个内核进行下载和分类 - 产生正常性能的1/16。我的直觉是,当我关闭并加入我的线程和池时,TensorFlow会发生泄漏。 TensorFlow在前几个快速分类循环中使用了我的所有核心。
我已经读过关于graph_def增长过大/为我的图表添加节点的每个jpeg读入。但事实上,TensorFlow幕后发生的事情对我来说仍然是一个谜。我很感激任何方向可以帮助我解决这个问题/让我在TensorFlow上做得更好。
谢谢!