tf.argmax does not work when used on a GPU (TensorFlow)

Time: 2019-01-24 20:58:30

Tags: tensorflow neural-network

I am having a problem using tf.argmax on the GPU. The problem occurs in the following code

for i in range(10): # Loop for the Epochs
  print ("\nEpoch:", i)

  for (batch, (images, labels)) in enumerate(dataset.take(60000)): # Loop for the mini-batches
    if batch % 100 == 0:
      #print('batches processed', batch)
      print('.', end='')
    labels = tf.cast(labels, dtype = tf.int64)

    with tf.device('/gpu:0'):
      with tf.GradientTape() as tape:
        logits = mnist_model(images, training=True)
        #print(logits)
        i64 = tf.constant(1, dtype=tf.int64)
        tgmax = tf.argmax(labels, axis = i64)
        loss_value = tf.losses.sparse_softmax_cross_entropy(tgmax, logits)

        loss_history.append(loss_value.numpy())
        grads = tape.gradient(loss_value, mnist_model.variables)
    optimizer.apply_gradients(zip(grads, mnist_model.variables),
                              global_step=tf.train.get_or_create_global_step())

This works on the CPU but not on the GPU (as written). It is a simple test using eager execution on the MNIST dataset. I believe the error is related to the data types, but I am not sure. The error I get is

NotFoundError: No registered 'ArgMax' OpKernel for GPU devices compatible with node {{node ArgMax}} = ArgMax[T=DT_INT64, Tidx=DT_INT64, output_type=DT_INT64](dummy_input, dummy_input)
     (OpKernel was found, but attributes didn't match)
    .  Registered:  device='XLA_GPU'; output_type in [DT_INT32, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, DT_COMPLEX64, DT_INT64, DT_QINT8, DT_QUINT8, DT_QINT32, DT_BFLOAT16, DT_HALF, DT_UINT32, DT_UINT64]; Tidx in [DT_INT32, DT_INT64]
  device='XLA_GPU_JIT'; output_type in [DT_INT32, DT_INT64]; T in [DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT8, DT_COMPLEX64, DT_INT64, DT_QINT8, DT_QUINT8, DT_QINT32, DT_BFLOAT16, DT_HALF, DT_UINT32, DT_UINT64]; Tidx in [DT_INT32, DT_INT64]
  device='GPU'; T in [DT_DOUBLE]; output_type in [DT_INT32]; Tidx in [DT_INT32]

Does anyone have a simple solution? Thanks in advance, Umberto

1 answer:

Answer 0 (score: 0)

This looks like the bug mentioned in The yield statement. Please report the issue to tensorflow on github.
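For readers hitting the same error: the kernel registration list in the question shows that, in this build, the GPU ArgMax kernel only accepts T=DT_DOUBLE with an int32 axis and int32 output, while the call passes int64 for all three. A common workaround for an unregistered GPU kernel is to pin just that op to the CPU. A minimal sketch, assuming eager execution and one-hot int64 labels as in the question (the sample data below is hypothetical):

```python
import tensorflow as tf

# Stand-in for the one-hot int64 labels used in the question (hypothetical data).
labels = tf.constant([[0, 0, 1], [1, 0, 0]], dtype=tf.int64)

# Pin only the argmax op to the CPU, where an int64 ArgMax kernel is registered;
# the rest of the training step can stay under tf.device('/gpu:0').
with tf.device('/cpu:0'):
    tgmax = tf.argmax(labels, axis=1)

print(tgmax.numpy())  # prints [2 0]
```

An alternative, untested against that exact build, is to match the registered GPU kernel instead: cast the labels to float64 and request an int32 result, e.g. `tf.argmax(tf.cast(labels, tf.float64), axis=1, output_type=tf.int32)`. Note also that `tf.losses.sparse_softmax_cross_entropy` expects integer class indices, so if the dataset already yields scalar labels rather than one-hot vectors, the argmax can be dropped entirely.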