No speedup from torch.multiprocessing.Queue

Date: 2019-12-30 04:42:09

Tags: python deep-learning multiprocessing shared-memory torch

My training system consists of a set of processes that exchange data in the form of tensors, or lists/dicts of tensors. Sharing memory via the torch.multiprocessing module is a known technique for speeding up such workflows, yet for some reason it does not help my application.

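For context, the technique in its simplest form looks like the minimal sketch below: a tensor whose storage has been moved to shared memory with share_memory_() can be updated in place by a worker process, and the parent sees the change without any copying (the expected output assumes the default fork start method on Linux).

import torch
import torch.multiprocessing as mp


def worker(t):
    # In-place update; the write goes directly into the shared storage.
    t += 1


if __name__ == '__main__':
    tensor = torch.zeros(3)
    tensor.share_memory_()  # move the underlying storage to shared memory
    p = mp.Process(target=worker, args=(tensor,))
    p.start()
    p.join()
    print(tensor)  # expected: tensor([1., 1., 1.])
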
Here is a test script that simulates the system: we create a process and send tensors to it through a queue:

import sys
import time

import torch
from torch.multiprocessing import Process as TorchProcess
from torch.multiprocessing import Queue as TorchQueue


q = TorchQueue()


def torch_shared_mem_process():
    counter = 0

    while True:
        data = q.get()
        counter += 1

        if data is None:
            return

        print('Received data:', len(data), data, counter)


def test_mem_share(share_memory):
    p = TorchProcess(target=torch_shared_mem_process)
    p.start()

    def sample_data():
        return torch.rand([1000, 128, 72, 3], dtype=torch.float)

    start = time.time()

    n = 50
    for i in range(n):
        data = sample_data()

        for data_item in data:
            if share_memory:
                data_item.share_memory_()

        q.put(data)

        print(f'Progress {i}/{n}')

    q.put(None)
    p.join()

    print(f'Finished sending {n} tensor lists!')

    took_seconds = time.time() - start
    return took_seconds


def main():
    no_shared_memory = test_mem_share(share_memory=False)
    with_shared_memory = test_mem_share(share_memory=True)

    print(f'Took {no_shared_memory:.1f} s without shared memory.')
    print(f'Took {with_shared_memory:.1f} s with shared memory.')


if __name__ == '__main__':
    sys.exit(main())

Since I am using torch.multiprocessing, I expected the version with share_memory=True to be faster, but in fact it is slightly slower:

Took 10.2 s without shared memory.
Took 11.7 s with shared memory.

Am I misunderstanding how torch.multiprocessing.Queue works?

1 Answer:

Answer 0: (score: 0)

I believe torch.multiprocessing.Queue already moves tensors into shared memory as it transfers them, so an explicit data_item.share_memory_() call should not speed things up any further.
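
One way to check this is to inspect a tensor on the receiving side with Tensor.is_shared(), which reports whether its storage lives in shared memory. A minimal sketch (the True output is the expected behaviour, even though the sender never calls share_memory_()):

import torch
from torch.multiprocessing import Process, Queue


def consumer(queue):
    t = queue.get()
    # A tensor received through a torch.multiprocessing.Queue is rebuilt
    # from shared-memory storage, so this is expected to print True.
    print('is_shared:', t.is_shared())


if __name__ == '__main__':
    queue = Queue()
    p = Process(target=consumer, args=(queue,))
    p.start()
    queue.put(torch.rand(4, 4))  # note: no explicit share_memory_() call
    p.join()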