使用tf.app.run()和argparse

时间:2018-05-16 12:44:40

标签: python-3.x parsing tensorflow distributed-computing

我已经理解了解析器的功能,但在以下代码中与tf.app.run()混合使用时,我无法使用它:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.register("type", "bool", lambda v: v.lower() == "true")

    parser.add_argument("--ps_hosts",
                        type=str,
                        default="",
                        help="Comma-seperated list of hostname:port pairs")
    parser.add_argument("--worker_hosts",
                        type=str,
                        default="",
                        help="Comma-seperated list of hostname:port pairs")
    parser.add_argument("--job_name",
                        type=str,
                        default="",
                        help="One of 'ps', 'worker'")
    parser.add_argument("--task_index",
                        type=int,
                        default=0,
                        help="Index of task within the job")
    parser.add_argument("--data_dir",
                        type=str,
                    -   default="/tmp/mnist_data",
                        help="Directory for storing input data")
    parser.add_argument("--log_dir",
                        type=str,
                        default="/tmp/train_logs",
                        help="Directory of train logs")
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

程序中的main函数没有任何参数,因为它被定义为def main(_)。那么argv中的tf.app.run()参数应该是什么意思呢?

由于

1 个答案:

答案 0 :(得分:2)

argv参数用于Tensorflow的内置命令行标志解析。它主要用于演示。您可以定义tf.flags.DEFINE_integer('batch_size', 128)之类的标记。然后,您就可以使用tf.flags.FLAGS.batch_size访问它。

如果您使用ArgumentParser解析参数,则无需使用tf.app.run