TensorFlow distributed training of a Transformer model

Date: 2021-01-27 22:19:08

Tags: python tensorflow machine-learning keras deep-learning

I am trying to adapt TensorFlow's Transformer tutorial to run on multiple GPUs, following the distributed training tutorial:

Transformer: https://www.tensorflow.org/tutorials/text/transformer
Distributed training: https://www.tensorflow.org/tutorials/distribute/custom_training

However, when I run the program I get the error `TypeError: train_step() argument after ** must be a mapping, not Tensor`, raised at the line `per_replica_losses = strategy.run(train_step, args=(inp, tar),)`.
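
For comparison, the distributed-training tutorial passes the whole dataset element to `strategy.run` as a single entry in `args`, roughly like this (paraphrased from the tutorial, so treat it as a sketch rather than my exact code):

```python
@tf.function
def distributed_train_step(dataset_inputs):
    # The tutorial hands the full dataset element to train_step as one argument
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM,
                           per_replica_losses, axis=None)
```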

The code causing the problem is as follows:

```python
# Both step functions are traced with the same signature: two int64 token tensors
train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    # inp, tar = dataset_inputs
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp,
                                     True,
                                     enc_padding_mask,
                                     combined_mask,
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_accuracy(accuracy_function(tar_real, predictions))


@tf.function(input_signature=train_step_signature)
def distributed_train_step(inp, tar):
    per_replica_losses = strategy.run(train_step, args=(inp, tar),)

    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)


for epoch in range(EPOCHS):
    start = time.time()
    BatchTime = time.time()

    train_loss.reset_states()
    train_accuracy.reset_states()

    for (batch, (dataset_inputs)) in enumerate(train_dataset):

        distributed_train_step(dataset_inputs)

        if batch % 50 == 0:
            print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                epoch + 1, batch, train_loss.result(), train_accuracy.result()))

        if batch % 100 == 0:
            ckpt_save_path = ckpt_manager.save()
            print('Saving checkpoint for batch {} at {}'.format(batch,
                                                                ckpt_save_path))
            print("Time taken: {}".format(time.time() - BatchTime))
            BatchTime = time.time()

    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                        train_loss.result(),
                                                        train_accuracy.result()))
    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
```
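
For reference, the training loop in the distribute tutorial also calls the step function with each dataset element as a single argument, along these lines (again a sketch paraphrased from the tutorial, using that tutorial's names, not my code):

```python
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    # train_dist_dataset is the dataset returned by
    # strategy.experimental_distribute_dataset in the tutorial
    for x in train_dist_dataset:
        total_loss += distributed_train_step(x)
        num_batches += 1
    train_loss = total_loss / num_batches
```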

Thanks

0 Answers:

No answers