I found that tf.estimator needs to load the entire model from the checkpoint files into the evaluator in order to export a SavedModel. However, when we train a model with large-scale embeddings using ParameterServerStrategy, the model size can exceed the evaluator's memory, so the model may fail to load. In that case, how can I export a SavedModel with the Estimator API?

For reference, here is the relevant internal export code:
with tf_session.Session(config=self._session_config) as session:
  if estimator_spec.scaffold.local_init_op is not None:
    local_init_op = estimator_spec.scaffold.local_init_op
  else:
    local_init_op = monitored_session.Scaffold.default_local_init_op()

  # This saver will be used both for restoring variables now,
  # and in saving out the metagraph below. This ensures that any
  # Custom Savers stored with the Scaffold are passed through to the
  # SavedModel for restore later.
  graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True)

  if save_variables and not check_variables:
    raise ValueError('If `save_variables` is `True`, `check_variables` '
                     'must not be `False`.')
  if check_variables:
    try:
      graph_saver.restore(session, checkpoint_path)
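For context, the snippet above is reached through the standard `Estimator.export_saved_model` path, which builds the serving graph in a single process and restores the whole checkpoint into it. A minimal sketch of that entry point is below; the feature spec and tensor names are illustrative assumptions, not taken from any real model, and `estimator` stands for whatever `tf.estimator.Estimator` instance was trained with ParameterServerStrategy:

```python
import tensorflow as tf


def serving_input_receiver_fn():
    # Accepts serialized tf.Example protos at serving time.
    serialized = tf.compat.v1.placeholder(
        dtype=tf.string, shape=[None], name="input_example")
    # Hypothetical feature spec for a model with an id-based embedding lookup.
    features = tf.io.parse_example(
        serialized, {"ids": tf.io.FixedLenFeature([1], tf.int64)})
    return tf.estimator.export.ServingInputReceiver(
        features, {"examples": serialized})


# estimator.export_saved_model builds the serving graph, then runs the
# `graph_saver.restore(session, checkpoint_path)` call quoted above, so the
# entire checkpoint -- including the large embedding tables -- must fit in
# that one process's memory:
# estimator.export_saved_model("/tmp/export", serving_input_receiver_fn)
```

This is what makes the export fail for models whose sharded checkpoint is larger than a single evaluator's RAM: the restore is not distributed across parameter servers the way training is.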