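I wonder whether this script can easily be converted to run in a Jupyter notebook without this error. The function tf.app.run() provides a wrapper that handles flag parsing, but it looks like the TensorFlow code forcibly exits the running process once main finishes.
Here is my main function: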
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

# make_input_fn, make_model and TRAIN_EPOCHS are defined elsewhere in the notebook.
def main(_):
    input_fn = make_input_fn
    hparams = tf.contrib.training.HParams(
        learning_rate=.1,
    )
    config = tf.ConfigProto(
        # allow_soft_placement=True,
        # log_device_placement=True
    )
    trainingConfig = tf.contrib.learn.RunConfig(
        save_summary_steps=500,
        save_checkpoints_steps=500,
        model_dir=("/tmp/tf-logs/bucketized-01"),
        session_config=config
    )
    estimator = tf.estimator.Estimator(
        model_fn=make_model,
        params=hparams,
        config=trainingConfig
    )
    estimator.train(
        input_fn=input_fn,
        steps=TRAIN_EPOCHS,
    )
When I run it in a Jupyter notebook like this:
if __name__ == '__main__':
    tf.app.run(main)
I get this error:
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from /tmp/tf-logs/bucketized-01/model.ckpt-2001
INFO:tensorflow:Saving checkpoints for 2002 into /tmp/tf-logs/bucketized-01/model.ckpt.
INFO:tensorflow:loss = 11.734686, step = 2002
INFO:tensorflow:global_step/sec: 4.84241
INFO:tensorflow:loss = 11.320501, step = 2102 (20.653 sec)
INFO:tensorflow:global_step/sec: 5.54159
INFO:tensorflow:loss = 9.874545, step = 2202 (18.044 sec)
INFO:tensorflow:global_step/sec: 5.20988
INFO:tensorflow:loss = 11.533301, step = 2302 (19.196 sec)
INFO:tensorflow:Saving checkpoints for 2401 into /tmp/tf-logs/bucketized-01/model.ckpt.
INFO:tensorflow:Loss for final step: 10.57784.
An exception has occurred, use %tb to see the full traceback.
SystemExit
Answer:
Here is the full tf.app.run function:
# Module-level imports in tensorflow/python/platform/app.py:
import sys as _sys
from tensorflow.python.platform import flags

def run(main=None, argv=None):
    """Runs the program with an optional 'main' function and 'argv' list."""
    f = flags.FLAGS
    # Extract the args from the optional `argv` list.
    args = argv[1:] if argv else None
    # Parse the known flags from that list, or from the command
    # line otherwise.
    # pylint: disable=protected-access
    flags_passthrough = f._parse_flags(args=args)
    # pylint: enable=protected-access
    main = main or _sys.modules['__main__'].main
    # Call the main function, passing through any arguments
    # to the final program.
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
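Note that sys.exit() does not kill the interpreter outright: it raises SystemExit, and IPython catches that exception and prints it instead of shutting the kernel down, which is exactly the last line of the log in the question. This stdlib-only sketch reproduces the same control flow; run_like_tf_app and fake_main are hypothetical stand-ins, not TensorFlow APIs.

```python
import sys


def run_like_tf_app(main, argv=None):
    """Hypothetical stand-in mimicking the last line of tf.app.run: sys.exit(main(...))."""
    args = list(argv[1:]) if argv else []
    # Like tf.app.run, hand main the program name plus the remaining args,
    # then pass main's return value straight to sys.exit().
    sys.exit(main(sys.argv[:1] + args))


def fake_main(_):
    # Like the question's main(), this returns None once "training" is done.
    print("training finished")


try:
    run_like_tf_app(fake_main)
except SystemExit as exc:
    # sys.exit() raised SystemExit; main returned None, so the exit code is None.
    print("caught SystemExit, code =", exc.code)
```

Because SystemExit is an ordinary exception (a subclass of BaseException), the notebook cell stops there, but the kernel and the trained checkpoint files are unaffected.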
Yes, it explicitly calls sys.exit(), so it should not be used in Jupyter. If you only need the flag parsing, call flags.FLAGS._parse_flags(args=args) directly, or use this version:
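import sys
from tensorflow.python.platform import flags

def run(main=None, argv=None):
    # Same flag parsing as tf.app.run (_parse_flags is a private TF 1.x API).
    args = argv[1:] if argv else None
    flags_passthrough = flags.FLAGS._parse_flags(args=args)
    main = main or sys.modules['__main__'].main
    # Call main directly and return its result instead of passing it to sys.exit().
    return main(sys.argv[:1] + flags_passthrough)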