tf.app.run()
如何在Tensorflow翻译演示?
在tensorflow/models/rnn/translate/translate.py
中,有tf.app.run()
的电话。它是如何处理的?
if __name__ == "__main__":
tf.app.run()
答案 0 :(得分:113)
if __name__ == "__main__":
表示当前文件在shell下执行,而不是作为模块导入。
tf.app.run()
正如您可以看到文件app.py
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
让我们一行一行:
flags_passthrough = f._parse_flags(args=args)
这可确保您通过命令行传递的参数有效,例如。
python my_model.py --data_dir='...' --max_iteration=10000
实际上,此功能是基于python标准argparse
模块实现的。
main = main or sys.modules['__main__'].main
main
右侧的第一个=
是当前函数run(main=None, argv=None)
的第一个参数
。虽然sys.modules['__main__']
表示当前正在运行的文件(例如my_model.py
)。
所以有两种情况:
main
中没有my_model.py
功能,那么你必须这样做
致电tf.app.run(my_main_running_function)
main
中有my_model.py
个功能。 (情况大多如此。)
最后一行:
sys.exit(main(sys.argv[:1] + flags_passthrough))
确保使用正确的解析参数调用main(argv)
或my_main_running_function(argv)
函数。
答案 1 :(得分:72)
它只是一个非常快速的包装器,它处理标志解析,然后发送到您自己的主服务器。请参阅code。
答案 2 :(得分:5)
tf.app
没有什么特别之处。这只是generic entry point script,
使用可选的' main'运行程序。功能和' argv'列表。
它与神经网络无关,只是调用main函数,通过任何参数传递给它。
答案 3 :(得分:3)
简单来说, tf.app.run()
的工作是首先设置全局标志供以后使用,如:
from tensorflow.python.platform import flags
f = flags.FLAGS
然后使用一组参数运行自定义主函数。
例如在TensorFlow NMT代码库中,此时开始进行训练/推理的程序执行的第一个入口点(参见下面的代码)
if __name__ == "__main__":
nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
FLAGS, unparsed = nmt_parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
使用argparse
解析参数后,使用tf.app.run()
运行“main”函数,其定义如下:
def main(unused_argv):
default_hparams = create_hparams(FLAGS)
train_fn = train.train
inference_fn = inference.inference
run_main(FLAGS, default_hparams, train_fn, inference_fn)
因此,在设置全局使用的标志后, tf.app.run()
只需运行您传递给它的main
函数,并将argv
作为参数。< / p>
PS:正如Salvador Dali's answer所说,这只是一个很好的软件工程实践,我想,虽然我不确定TensorFlow是否执行main
函数的任何优化运行,而不是使用正常运行CPython的。
答案 4 :(得分:1)
Google代码在很大程度上取决于在库/二进制文件/ python脚本中正在访问的全局标志,因此tf.app.run()解析出这些标志以在FLAGs(或类似的东西)变量中创建全局状态,然后调用python main()应该。
如果没有对此tf.app.run()的调用,则用户可能会忘记进行FLAG解析,从而导致这些库/二进制文件/脚本无法访问所需的FLAG。
答案 5 :(得分:1)
2.0兼容答案:如果要在 tf.app.run()
中使用Tensorflow 2.0
,则应使用命令
tf.compat.v1.app.run()
,也可以使用tf_upgrade_v2
将1.x
代码转换为2.0
。