实例化multigpu模型时Keras失去图形连接

时间:2018-06-03 18:17:06

标签: python tensorflow keras

当在TF后端实例化具有多个输出/输入的multigpu模型时,Keras会丢失一些图形连接,导致无梯度破坏模型编译。单个GPU模型将编译没有问题,但与multigpu模型相同的模型失败。请参阅this example

我确实制作了Keras issue,但我正在寻找补救措施。我认为可能有一个聪明的图层/模型命名的修复程序,如果你在Keras后端找到足够的东西,我确定有一个修复程序;特别是本节(204-227,multi_gpu_utils.py):

# Place a copy of the model on each GPU,
# each getting a slice of the inputs.
for i, gpu_id in enumerate(target_gpu_ids):
    with tf.device('/gpu:%d' % gpu_id):
        with tf.name_scope('replica_%d' % gpu_id):
            inputs = []
            # Retrieve a slice of the input.
            for x in model.inputs:
                input_shape = tuple(x.get_shape().as_list())[1:]
                slice_i = Lambda(get_slice,
                                 output_shape=input_shape,
                                 arguments={'i': i,
                                            'parts': num_gpus})(x)
                inputs.append(slice_i)

            # Apply model on slice
            # (creating a model replica on the target device).
            outputs = model(inputs)
            if not isinstance(outputs, list):
                outputs = [outputs]

            # Save the outputs for merging back together later.
            for o in range(len(outputs)):
                all_outputs[o].append(outputs[o])

但我不确定。任何想法/修复都非常感激。

0 个答案:

没有答案