I am trying to do multi-GPU training with tf.distribute.MirroredStrategy() in TensorFlow 2 (nightly) with Keras.
The problem is that with a 2x larger total batch size on 2 GPUs (16), each epoch takes roughly 3x longer to train. Does anyone have a clue where my problem might be? tf.distribute.MirroredStrategy is supposed to do all the work for us, and the model trains fast on a single GPU.
It is a single-input, two-output model. Could that be what is causing this?
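For context, my distribution setup is essentially the standard MirroredStrategy recipe. The sketch below is simplified: build_model(), train_dataset, and the two losses are placeholders for my actual code, and the per-GPU batch size of 8 is what gives the total of 16 on 2 GPUs.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)  # 2 on my machine

# Global batch size scaled with the number of replicas (8 per GPU -> 16 total).
PER_REPLICA_BATCH_SIZE = 8
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync

with strategy.scope():
    # build_model() stands in for the code that builds the single-input,
    # two-output model summarized below.
    model = build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        # Placeholder losses, one per output (out1, out2).
        loss={"out1": "mse", "out2": "categorical_crossentropy"},
    )

# train_dataset is a tf.data.Dataset yielding (input, (target1, target2)) pairs.
model.fit(train_dataset.batch(GLOBAL_BATCH_SIZE), epochs=10)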
Here is the model summary:
Model: "Joined_Model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
inp (InputLayer) [(None, None, 257)] 0
__________________________________________________________________________________________________
time_distributed_2 (TimeDistrib (None, None, 512) 131584 inp[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, None, 512) 1024 time_distributed_2[0][0]
__________________________________________________________________________________________________
re_lu_1 (ReLU) (None, None, 512) 0 layer_normalization_1[0][0]
__________________________________________________________________________________________________
lstm_5 (LSTM) (None, 512) 2099200 re_lu_1[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, None, 512) 0 re_lu_1[0][0]
lstm_5[0][0]
__________________________________________________________________________________________________
lstm_6 (LSTM) (None, 512) 2099200 add_5[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, None, 512) 0 add_5[0][0]
lstm_6[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM) (None, 512) 2099200 add_6[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, None, 512) 0 add_6[0][0]
lstm_7[0][0]
__________________________________________________________________________________________________
lstm_8 (LSTM) (None, 512) 2099200 add_7[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, None, 512) 0 add_7[0][0]
lstm_8[0][0]
__________________________________________________________________________________________________
lstm_9 (LSTM) (None, 512) 2099200 add_8[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, None, 512) 0 add_8[0][0]
lstm_9[0][0]
__________________________________________________________________________________________________
time_distributed_3 (TimeDistrib (None, None, 257) 131841 add_9[0][0]
__________________________________________________________________________________________________
out1 (Activation) (None, None, 257) 0 time_distributed_3[0][0]
__________________________________________________________________________________________________
tf_op_layer_Add_1 (TensorFlowOp [(None, None, 257)] 0 out1[0][0]
__________________________________________________________________________________________________
tf_op_layer_RealDiv_1 (TensorFl [(None, None, 257)] 0 out1[0][0]
tf_op_layer_Add_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_Sqrt_1 (TensorFlowO [(None, None, 257)] 0 tf_op_layer_RealDiv_1[0][0]
__________________________________________________________________________________________________
tf_op_layer_Mul_1 (TensorFlowOp [(None, None, 257)] 0 inp[0][0]
tf_op_layer_Sqrt_1[0][0]
__________________________________________________________________________________________________
log_mel_spectrogram_1 (LogMelSp (None, None, 160) 0 tf_op_layer_Mul_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda) (None, None, 160, 1) 0 log_mel_spectrogram_1[0][0]
__________________________________________________________________________________________________
conv_1 (Conv2D) (None, None, 80, 32) 14432 lambda_1[0][0]
__________________________________________________________________________________________________
conv_1_bn (BatchNormalization) (None, None, 80, 32) 128 conv_1[0][0]
__________________________________________________________________________________________________
conv_1_relu (ReLU) (None, None, 80, 32) 0 conv_1_bn[0][0]
__________________________________________________________________________________________________
conv_2 (Conv2D) (None, None, 40, 32) 236544 conv_1_relu[0][0]
__________________________________________________________________________________________________
conv_2_bn (BatchNormalization) (None, None, 40, 32) 128 conv_2[0][0]
__________________________________________________________________________________________________
conv_2_relu (ReLU) (None, None, 40, 32) 0 conv_2_bn[0][0]
__________________________________________________________________________________________________
after_conv (Reshape) (None, None, 1280) 0 conv_2_relu[0][0]
__________________________________________________________________________________________________
bidirectional_1 (Bidirectional) (None, None, 1600) 9993600 after_conv[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, None, 1600) 0 bidirectional_1[0][0]
__________________________________________________________________________________________________
bidirectional_2 (Bidirectional) (None, None, 1600) 11529600 dropout_4[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, None, 1600) 0 bidirectional_2[0][0]
__________________________________________________________________________________________________
bidirectional_3 (Bidirectional) (None, None, 1600) 11529600 dropout_5[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout) (None, None, 1600) 0 bidirectional_3[0][0]
__________________________________________________________________________________________________
bidirectional_4 (Bidirectional) (None, None, 1600) 11529600 dropout_6[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout) (None, None, 1600) 0 bidirectional_4[0][0]
__________________________________________________________________________________________________
bidirectional_5 (Bidirectional) (None, None, 1600) 11529600 dropout_7[0][0]
__________________________________________________________________________________________________
dense_1 (TimeDistributed) (None, None, 1600) 2561600 bidirectional_5[0][0]
__________________________________________________________________________________________________
dense_1_relu (ReLU) (None, None, 1600) 0 dense_1[0][0]
__________________________________________________________________________________________________
dense_1_bn (BatchNormalization) (None, None, 1600) 6400 dense_1_relu[0][0]
__________________________________________________________________________________________________
out2 (TimeDistributed) (None, None, 29) 46429 dense_1_bn[0][0]
==================================================================================================