我正在GPU上执行PyTorch函数n_iters次。目前,我为此使用for循环。但是,当n_iters大时,效率很低。我想知道是否有具有parallel_iterations功能的tf.map_fn的PyTorch等效项,以便我可以并行执行所有迭代?
答案 0 :(得分:2)
我进行了深入的搜索,但无法在pytorch中找到任何等效于tf.map_fn的函数,该函数暴露了用户要设置的parallel_iterations数量。
在探索的过程中,我发现有一个名为“ nn.DataParallel”的函数,但是此函数复制了要在多个GPU上运行的模型或操作,然后返回与不等的结果tf.map_fn中的parallel_iterations数量。
但是目前在Pytorch中,使用for循环是目前唯一的方法。