在pytorch数据加载器中从远程服务器获取数据

时间:2020-02-21 13:17:43

标签: python pytorch

我有一个很大的hd5文件(〜100GB),其中包含来自resnet的图像功能。该文件位于我的本地计算机(笔记本电脑)上。我的模型是在存储限制为25GB的群集节点上训练的。 现在,我正在使用torch.distributed.rpc将数据从本地计算机传输到群集。 我将通过以下方式在本地计算机上公开服务器,

num_worker = 4

utils.WORLD_SIZE = num_worker +1

import os
import torch
import utils
import torch.distributed.rpc as rpc

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8888'
    
    rpc.init_rpc(utils.SERVER_NAME,
                 rank = rank, 
                 world_size = world_size)
    print("Server Initialized", flush=True)
    rpc.shutdown()

if __name__ == "__main__":
    rank = 0
    world_size = utils.WORLD_SIZE
    run_worker(rank, world_size)

此服务器将数据从本地计算机发送到群集。 (其他类被省略)

现在要从集群请求数据,我正在使用 worker_init_fn 用于数据加载器,为每个工作程序初始化rpc,

def worker_init_fn(worker_id):
    rpc.init_rpc(utils.CLIENT_NAME.format(worker_id+1),
                rank=worker_id+1, world_size=utils.WORLD_SIZE)
    
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id
    server_info = rpc.get_worker_info(utils.SERVER_NAME)
    dataset.server_ref = rpc.remote(server_info, utils.Server)

现在,当我运行我的代码时,训练循环完成了数据集的一次迭代并在此之后挂起,并且在集群方面出现了以下错误,

Traceback (most recent call last):
  File "custom_datasets.py", line 134, in <module>
    main()
  File "custom_datasets.py", line 110, in main
    for i, (images, labels) in enumerate(mn_dataset_loader):
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 135, in _worker_loop
    init_fn(worker_id)
  File "custom_datasets.py", line 78, in worker_init_fn
    rank=worker_id+1, world_size=utils.WORLD_SIZE)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/distributed/rpc/__init__.py", line 67, in init_rpc
    store, _, _ = next(rendezvous_iterator)
  File "/home/kanishk/.local/lib/python3.7/site-packages/torch/distributed/rendezvous.py", line 168, in _env_rendezvous_handler
    store = TCPStore(master_addr, master_port, world_size, start_daemon)
RuntimeError: connect() timed out.

当我设置num_worker = 0时,上面的问题没有发生,但是群集代码非常慢。我认为该错误是由于多线程引起的,但我不确定如何解决此问题。请帮助我解决问题。

0 个答案:

没有答案