tf.app.run()如何工作?

时间:2015-11-14 00:03:33

标签: python python-3.x tensorflow

tf.app.run()如何在Tensorflow翻译演示?

tensorflow/models/rnn/translate/translate.py中,有tf.app.run()的电话。它是如何处理的?

if __name__ == "__main__":
    tf.app.run() 

6 个答案:

答案 0 :(得分:113)

if __name__ == "__main__":

表示当前文件在shell下执行,而不是作为模块导入。

tf.app.run()

正如您可以看到文件app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们一行一行:

flags_passthrough = f._parse_flags(args=args)

这可确保您通过命令行传递的参数有效,例如。 python my_model.py --data_dir='...' --max_iteration=10000实际上,此功能是基于python标准argparse模块实现的。

main = main or sys.modules['__main__'].main

main右侧的第一个=是当前函数run(main=None, argv=None)的第一个参数 。虽然sys.modules['__main__']表示当前正在运行的文件(例如my_model.py)。

所以有两种情况:

  1. main中没有my_model.py功能,那么你必须这样做 致电tf.app.run(my_main_running_function)

  2. main中有my_model.py个功能。 (情况大多如此。)

  3. 最后一行:

    sys.exit(main(sys.argv[:1] + flags_passthrough))
    

    确保使用正确的解析参数调用main(argv)my_main_running_function(argv)函数。

答案 1 :(得分:72)

它只是一个非常快速的包装器,它处理标志解析,然后发送到您自己的主服务器。请参阅code

答案 2 :(得分:5)

tf.app没有什么特别之处。这只是generic entry point script

  

使用可选的' main'运行程序。功能和' argv'列表。

它与神经网络无关,只是调用main函数,通过任何参数传递给它。

答案 3 :(得分:3)

简单来说, tf.app.run() 的工作是首先设置全局标志供以后使用,如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后使用一组参数运行自定义主函数。

例如在TensorFlow NMT代码库中,此时开始进行训练/推理的程序执行的第一个入口点(参见下面的代码)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使用argparse解析参数后,使用tf.app.run()运行“main”函数,其定义如下:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置全局使用的标志后, tf.app.run() 只需运行您传递给它的main函数,并将argv作为参数。< / p>

PS:正如Salvador Dali's answer所说,这只是一个很好的软件工程实践,我想,虽然我不确定TensorFlow是否执行main函数的任何优化运行,而不是使用正常运行CPython的。

答案 4 :(得分:1)

Google代码在很大程度上取决于在库/二进制文件/ python脚本中正在访问的全局标志,因此tf.app.run()解析出这些标志以在FLAGs(或类似的东西)变量中创建全局状态,然后调用python main()应该。

如果没有对此tf.app.run()的调用,则用户可能会忘记进行FLAG解析,从而导致这些库/二进制文件/脚本无法访问所需的FLAG。

答案 5 :(得分:1)

2.0兼容答案:如果要在 tf.app.run() 中使用Tensorflow 2.0,则应使用命令

tf.compat.v1.app.run(),也可以使用tf_upgrade_v21.x代码转换为2.0