在TPU上训练时,如何在tensorFlow中使用交叉熵损失?

时间:2020-06-23 12:47:23

标签: python tensorflow tpu

我正在尝试在TPU上训练变压器编码器(从这里开始https://www.tensorflow.org/tutorials/text/transformer

def test():
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]
    @tf.function(input_signature=train_step_signature)
    def train_step(inp, tar):
      with tf.GradientTape() as tape:
        predictions = transf(inp,  
                                     True, 
                                     None, 
                                     None, 
                                     None)
        loss = loss_function(tar, predictions) # <- error is here
                                               # I use SparseCategoricalCrossentropy()

    vocabsize=1000
    transf = Transformer(num_layers, d_model, num_heads, dff,
                              vocabsize, vocabsize, 
                              pe_input=vocabsize, 
                              pe_target=vocabsize,
                              rate=dropout_rate)
    for iter in range(1,75000):
      print(iter)
      inp=np.random.randint(vocabsize, size=(5,11))
      tar=np.random.randint(vocabsize, size=(5,11))
      train_step(inp,tar)

它在CPU上工作。但是在TPU上进行约100次迭代后,调用loss_function时出现错误(如上所示):

InvalidArgumentError:

以下节点调用的函数不可编译:{{node __inference_train_step_4179}} = __inference_train_step_4179 [_XlaMustCompile = true,config_proto =“ \ n \ 007 \ n \ 003GPU \ 020 \ 000 \ n \ 007 \ n \ 003CPU \ 020 \ 0012 \ 002J \ 0008 \ 001“,executor_type =”“(虚拟输入,虚拟输入,虚拟输入,虚拟输入...

无法编译的节点: sparse_categorical_crossentropy / SparseSoftmaxCrossEntropyWithLogits / assert_equal_1 / Assert / Const:

不受支持的操作:XLA不支持DT_STRING类型的Const操作。

Stacktrace:节点:__inference_train_step_4179,功能:节点: sparse_categorical_crossentropy / SparseSoftmaxCrossEntropyWithLogits / assert_equal_1 / Assert / Const,函数:__inference_train_step_4179 ...

据我所知-该错误是由Xla不支持的损失函数内的断言引起的。 我在这里可以做什么?

0 个答案:

没有答案