在chainer中,如何使用多个GPU编写BPTT更新程序?

时间:2017-08-21 03:07:25

标签: chainer

我没有找到示例,因为现有示例只扩展了training.StandardUpdater,因此只使用一个GPU。

1 个答案:

答案 0 :(得分:0)

我假设你在谈论the ptb example of ChainerBPTTUpdater

在多个GPU上进行自定义更新程序支持学习并不是直截了当的。 MultiprocessParallelUpdater硬代码计算渐变的方式(只有目标链接实现是可自定义的),因此您必须复制MultiprocessParallelUpdater的整体实现并修改渐变计算部分。您需要复制和编辑的内容是chainer/training/updaters/multiprocess_parallel_updater.py

此文件中有两个部分用于计算渐变; _Worker.run中的一个表示工作进程任务,另一个表示MultiprocessParallelUpdater.update_core,表示主进程任务。您必须通过在以下两部分中修改从_calc_lossbackward的代码来使这些代码执行BPTT:

# Change self._master into self.model for _Worker.run code
loss = _calc_loss(self._master, batch)
self._master.cleargrads()
loss.backward()

应该通过插入BPTTUpdater.update_core

的代码进行修改

您还必须注意数据迭代器。 MultiprocessParallelUpdater接受将分发给主/工作进程的迭代器集。由于ptb示例使用自定义迭代器(ParallelSequentialIterator),因此必须确保这些迭代器迭代数据集的不同部分或使用不同的单词位置初始偏移量。它也可能需要定制到ParalellSequentialIterator