我正在编写一个脚本,使用tensorflow retrain.py中的main()函数自动进行训练。通常从外壳程序使用已解析的参数调用此脚本。在retrain.py中:
if __name__ == __main__:
parser = argparse.ArgumentParser()
parser.add_argument(
'--image_dir',
type=str,
default='',
help='Path to folders of labeled images.'
)
parser.add_argument(
'--output_graph',
type=str,
default='/tmp/output_graph.pb',
help='Where to save the trained graph.'
)
...
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
我了解到tensorflow通常将FLAGS
参数作为全局变量处理,但我不明白如何将此变量设置为全局变量,因为在代码段中,FLAGS
应该是{{ 1}}对象。
但是,我尝试在自己的脚本中手动定义argparse.Namespace
变量:
FLAGS
始终得到错误from scripts.retrain import main
...
if __name__ == '__main__':
tf.app.flags.DEFINE_string('summaries_dir', summaries_dir, 'Help summaries_dir.')
tf.app.flags.DEFINE_string('image_dir', image_dir, 'Help image_dir.')
...
FLAGS = tf.app.flags.FLAGS
tf.app.run(main=main, argv=[sys.argv[0]] + ['python -m scripts.retrain.py'])
。我应该如何从脚本中运行retrain.py?