我正在尝试使用分布式训练教程调整 TensorFlow 的转换器教程以在多个 GPU 上工作,
变压器:https://www.tensorflow.org/tutorials/text/transformer 分布式训练:https://www.tensorflow.org/tutorials/distribute/custom_training
但是,当我执行程序时,我收到错误消息:per_replica_losses = strategy.run(train_step, args=(inp, tar),)
中的“TypeError: train_step() 参数后面的 ** 必须是映射,而不是张量”。
导致问题的代码如下:
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
#inp, tar = dataset_inputs
tar_inp = tar[:, :-1]
tar_real = tar[:, 1:]
enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
with tf.GradientTape() as tape:
predictions, _ = transformer(inp, tar_inp,
True,
enc_padding_mask,
combined_mask,
dec_padding_mask)
loss = loss_function(tar_real, predictions)
gradients = tape.gradient(loss, transformer.trainable_variables)
optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))
train_loss(loss)
train_accuracy(accuracy_function(tar_real, predictions))`
@tf.function(input_signature=train_step_signature)
def distributed_train_step(inp, tar):
per_replica_losses = strategy.run(train_step, args=(inp, tar),)
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
for epoch in range(EPOCHS):
start = time.time()
BatchTime = time.time()
train_loss.reset_states()
train_accuracy.reset_states()
for (batch, (dataset_inputs)) in enumerate(train_dataset):
distributed_train_step(dataset_inputs)
if batch % 50 == 0:
print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, batch, train_loss.result(), train_accuracy.result()))
if batch % 100 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for batch {} at {}'.format(batch,
ckpt_save_path))
print("Time taken: {}".format(time.time() - BatchTime))
BatchTime = time.time()
print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
train_loss.result(),
train_accuracy.result()))
print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))```
Thanks