TF 2.0 @ tf.function示例

时间:2019-03-22 13:20:46

标签: python tensorflow tensorflow2.0

this PhoneGap Build support page部分的tensorflow文档中,我们具有以下代码段

@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

我有一个关于step变量的小问题,它是一个整数而不是张量,签名支持内置的python类型,例如integer。因此,可以将tf.equal(step%10,0)更改为简单的step%10 == 0对吗?

1 个答案:

答案 0 :(得分:2)

是的,您是对的。即使将整数变量step转换为其图形表示形式,它仍然是Python变量。您可以通过调用tf.autograph.to_code(train.python_function)来查看转换结果。

无需报告所有代码,仅报告与变量step相关的部分,

  def loop_body(loop_vars, loss_1, step_1):
    with ag__.function_scope('loop_body'):
      x, y = loop_vars
      step_1 += 1

仍然是python操作(否则,如果第1步是step_1.assign_add(1),它将是tf.Tensor)。

有关签名和tf.function的更多信息,建议阅读文章https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/,该文章轻松地解释了转换函数时会发生什么情况。