我正在尝试在TPU上训练变压器编码器(从这里开始https://www.tensorflow.org/tutorials/text/transformer)
def test():
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
with tf.GradientTape() as tape:
predictions = transf(inp,
True,
None,
None,
None)
loss = loss_function(tar, predictions) # <- error is here
# I use SparseCategoricalCrossentropy()
vocabsize=1000
transf = Transformer(num_layers, d_model, num_heads, dff,
vocabsize, vocabsize,
pe_input=vocabsize,
pe_target=vocabsize,
rate=dropout_rate)
for iter in range(1,75000):
print(iter)
inp=np.random.randint(vocabsize, size=(5,11))
tar=np.random.randint(vocabsize, size=(5,11))
train_step(inp,tar)
它在CPU上工作。但是在TPU上进行约100次迭代后,调用loss_function时出现错误(如上所示):
InvalidArgumentError:
以下节点调用的函数不可编译:{{node __inference_train_step_4179}} = __inference_train_step_4179 [_XlaMustCompile = true,config_proto =“ \ n \ 007 \ n \ 003GPU \ 020 \ 000 \ n \ 007 \ n \ 003CPU \ 020 \ 0012 \ 002J \ 0008 \ 001“,executor_type =”“(虚拟输入,虚拟输入,虚拟输入,虚拟输入...
无法编译的节点: sparse_categorical_crossentropy / SparseSoftmaxCrossEntropyWithLogits / assert_equal_1 / Assert / Const:
不受支持的操作:XLA不支持DT_STRING类型的Const操作。
Stacktrace:节点:__inference_train_step_4179,功能:节点: sparse_categorical_crossentropy / SparseSoftmaxCrossEntropyWithLogits / assert_equal_1 / Assert / Const,函数:__inference_train_step_4179 ...
据我所知-该错误是由Xla不支持的损失函数内的断言引起的。 我在这里可以做什么?