Tensorflow 2:获取张量值

时间:2019-07-05 09:47:23

标签: python tensorflow tensorflow2.0

我正在切换到TF2,我只是遵循了tutorial, 火车和踏步功能现在定义为“ @ tf.function”。

如何打印张量y_pred和loss的值?

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)

    print("train preds: ", y_pred)
    print("train loss: ", loss)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)

1 个答案:

答案 0 :(得分:2)

print在Python世界中执行(不在图形中执行),因此,在tf.function跟踪函数构造图形时,它只会打印张量一次。如果要打印图形,请使用tf.print