没有NCCL的MirroredStrategy

时间:2018-06-05 13:36:50

标签: tensorflow

  • 我是否编写过自定义代码(与使用TensorFlow中提供的库存示例脚本相反):否
  • 操作系统平台和分发(例如,Linux Ubuntu 16.04):Windows 10 x64
  • 从(源代码或二进制代码)安装的TensorFlow :二进制文件
  • TensorFlow版本(使用下面的命令):1.8.0
  • Python版:3.6
  • Bazel版本(如果从源代码编译): -
  • GCC /编译器版本(如果从源代码编译): -
  • CUDA / cuDNN版本:9.0
  • GPU型号和内存:3.5
  • 重现的确切命令:simple_tfkeras_example.py

我想使用MirroredStrategy在同一台机器上使用多个GPU。我试过其中一个例子: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py

结果是: ValueError:操作类型未注册' NcclAllReduce'在RAID上运行二进制文件。确保在此过程中运行的二进制文件中注册了Op和Kernel。在构建NodeDef' NcclAllReduce'

我使用的是Windows,因此Nccl不可用。是否有可能强制TensorFlow不使用此库?

1 个答案:

答案 0 :(得分:0)

Windows上有一些用于NCCL的二进制文件,但是处理起来很烦人。

作为替代方案,Tensorflow在MirroredStrategy中为您提供了三个与Windows本机兼容的其他选项。它们是分层副本,精简至第一个GPU和精简至CPU。您最可能希望找到的是“层次副本”,但是您可以对其进行测试,以查看能给您带来最佳效果的东西。

如果您使用的Tensorflow版本早于2.0,则将使用tf.contrib.distribute:

# Hierarchical Copy
cross_tower_ops = tf.contrib.distribute.AllReduceCrossTowerOps(
        'hierarchical_copy', num_packs=number_of_gpus))
    strategy = tf.contrib.distribute.MirroredStrategy(cross_tower_ops=cross_tower_ops)

# Reduce to First GPU
cross_tower_ops = tf.contrib.distribute. ReductionToOneDeviceCrossTowerOps()
strategy = tf.contrib.distribute.MirroredStrategy(cross_tower_ops=cross_tower_ops)

# Reduce to CPU
cross_tower_ops = tf.contrib.distribute. ReductionToOneDeviceCrossTowerOps(
    reduce_to_device="/device:CPU:0")
strategy = tf.contrib.distribute.MirroredStrategy(cross_tower_ops=cross_tower_ops)

2.0之后,您只需要使用tf.distribute!以下是使用两个GPU设置Xception模型的示例:

strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"], 
                                          cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
with strategy.scope():
    parallel_model = Xception(weights=None,
                              input_shape=(299, 299, 3),
                              classes=number_of_classes)
    parallel_model.compile(loss='categorical_crossentropy', optimizer='rmsprop')