当在TF后端实例化具有多个输出/输入的multigpu模型时,Keras会丢失一些图形连接,导致无梯度破坏模型编译。单个GPU模型将编译没有问题,但与multigpu模型相同的模型失败。请参阅this example。
我确实制作了Keras issue,但我正在寻找补救措施。我认为可能有一个聪明的图层/模型命名的修复程序,如果你在Keras后端找到足够的东西,我确定有一个修复程序;特别是本节(204-227,multi_gpu_utils.py):
# Place a copy of the model on each GPU,
# each getting a slice of the inputs.
for i, gpu_id in enumerate(target_gpu_ids):
with tf.device('/gpu:%d' % gpu_id):
with tf.name_scope('replica_%d' % gpu_id):
inputs = []
# Retrieve a slice of the input.
for x in model.inputs:
input_shape = tuple(x.get_shape().as_list())[1:]
slice_i = Lambda(get_slice,
output_shape=input_shape,
arguments={'i': i,
'parts': num_gpus})(x)
inputs.append(slice_i)
# Apply model on slice
# (creating a model replica on the target device).
outputs = model(inputs)
if not isinstance(outputs, list):
outputs = [outputs]
# Save the outputs for merging back together later.
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
但我不确定。任何想法/修复都非常感激。