Writing tfrecord with multithreading is not as fast as expected

Time: 2018-07-24 17:13:20

Tags: multithreading performance tensorflow tfrecord

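I tried writing a tfrecord both with and without multithreading and found that the speed difference is small (with 4 threads: 434 seconds; without multithreading: 590 seconds). I am not sure whether I am using threads correctly. Is there a better way to write tfrecords faster?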

import tensorflow as tf 
import numpy as np 
import threading 
import time 


def generate_data(shape=[15,28,60,1]):
    return np.random.uniform(size=shape)


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def write_instances_to_tfrecord(tfrecord_file, filenames):
    tfrecord_writer = tf.python_io.TFRecordWriter(tfrecord_file)
    for i, filename in enumerate(filenames):
        curr_MFCC = generate_data()
        curr_MFCC_raw = curr_MFCC.tostring()
        curr_filename_raw = str(filename)+'-'+str(i)
        example = tf.train.Example(features=tf.train.Features(
            feature={
            'MFCC': _bytes_feature(curr_MFCC_raw),
            'filename': _bytes_feature(curr_filename_raw)
            })
        )
        tfrecord_writer.write(example.SerializeToString())
    tfrecord_writer.close()


def test():
    threading_start = time.time()
    coord = tf.train.Coordinator()
    threads = []
    for thread_index in xrange(4):
        args = (str(thread_index), range(200000))
        t = threading.Thread(target=write_instances_to_tfrecord, args=args)
        t.start()
        threads.append(t)
    coord.join(threads)
    print 'w/ threading takes', time.time()-threading_start

    start = time.time()
    write_instances_to_tfrecord('5', range(800000))
    print 'w/o threading takes', time.time()-start

if __name__ == '__main__':
    test()

2 Answers:

Answer 0 (score: 1)

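When you use Python threads, CPU utilization is limited to one core because of the GIL (global interpreter lock) in the CPython implementation. For CPU-bound Python code, you will not see a real speedup no matter how many threads you add.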
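
To make the GIL effect concrete, here is a minimal sketch (Python 3, standard library only) that times the same pure-Python loop run sequentially and in 4 threads. The names `busy_sum`, `run_sequential`, and `run_threaded` are made up for this illustration and are not part of the question's code. On a standard CPython build the two timings are usually close; the exact numbers depend on the machine, so none are shown.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def busy_sum(n):
    """Pure-Python CPU-bound loop; the running thread holds the GIL."""
    total = 0
    for i in range(n):
        total += i * i
    return total


def run_sequential(n, tasks):
    """Run the loop `tasks` times, one after another."""
    return [busy_sum(n) for _ in range(tasks)]


def run_threaded(n, tasks):
    """Run the loop `tasks` times, one thread per task."""
    with ThreadPoolExecutor(max_workers=tasks) as pool:
        return list(pool.map(busy_sum, [n] * tasks))


def timed(fn, *args):
    """Return (result, elapsed seconds) for a single call."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


if __name__ == "__main__":
    _, seq_seconds = timed(run_sequential, 1000000, 4)
    _, thr_seconds = timed(run_threaded, 1000000, 4)
    print("sequential: %.2fs, 4 threads: %.2fs" % (seq_seconds, thr_seconds))
```

Work that releases the GIL, such as file I/O, can still overlap across threads, which may be why the question saw a modest gain rather than none.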

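In your case, a simple solution is to use the multiprocessing module. The code is almost exactly the same as what you already have; just switch from threads to processes: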

from multiprocessing import Process
coord = tf.train.Coordinator()
processes = []
for thread_index in xrange(4):
    args = (str(thread_index), range(200000))
    p = Process(target=write_instances_to_tfrecord, args=args)
    p.start()
    processes.append(p)
coord.join(processes)

I tested this on my own tfrecord writer code and got a linear speedup. The total number of processes is limited by memory.
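
For reference, the same one-process-per-shard pattern can be sketched in Python 3 without TensorFlow. This is an illustration under stated assumptions, not the code I tested: `write_shard` is a hypothetical stand-in for `write_instances_to_tfrecord` that writes simple length-prefixed records, and the processes are waited on with plain `Process.join()` rather than through `tf.train.Coordinator`.

```python
import multiprocessing as mp
import os
import struct
import tempfile


def write_shard(path, num_records):
    """Hypothetical stand-in for write_instances_to_tfrecord."""
    with open(path, "wb") as f:
        for i in range(num_records):
            payload = ("%s-%d" % (os.path.basename(path), i)).encode("utf-8")
            f.write(struct.pack("<Q", len(payload)))  # 8-byte length prefix
            f.write(payload)


def write_shards_in_parallel(out_dir, num_shards, records_per_shard):
    """Start one process per shard, wait for all, and return the shard paths."""
    # Assumption: prefer the "fork" start method where the platform offers it;
    # otherwise fall back to the platform default.
    method = "fork" if "fork" in mp.get_all_start_methods() else None
    ctx = mp.get_context(method)
    paths = [os.path.join(out_dir, "shard-%d.bin" % k) for k in range(num_shards)]
    workers = [ctx.Process(target=write_shard, args=(path, records_per_shard))
               for path in paths]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
        if worker.exitcode != 0:
            raise RuntimeError("worker exited with code %s" % worker.exitcode)
    return paths


def count_records(path):
    """Read a shard back and count its length-prefixed records."""
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if not header:
                break
            (length,) = struct.unpack("<Q", header)
            f.read(length)
            count += 1
    return count


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        shards = write_shards_in_parallel(tmp, num_shards=4, records_per_shard=1000)
        print([count_records(shard) for shard in shards])
```

Each process writes its own output file, so the workers never share a writer and need no locking. The price is that the data ends up in several shard files instead of one.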

Answer 1 (score: 0)

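It is better to use a TensorFlow computation graph to take advantage of multithreading, because each session and graph can run in a different thread. With a computation graph, it was about 40 times faster.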