适用于多个GPU的TensorFlow数据集

时间:2019-05-06 03:14:32

标签: tensorflow deep-learning distributed-computing data-pipeline

我打算用多个GPU训练模型,我想知道为什么我确实需要使用分片。

我有以下代码。

dataset = ...
iterator = dataset.make_one_shot_iterator()

for i in range(FLAGS.num_gpus):
    with tf.device('/gpu:%d' % i):
        data = iterator.get_next()
        loss = ...

这样做是错误的方式吗?使用dataset.shard(...)更有利吗?

0 个答案:

没有答案