How to export a SavedModel from an Estimator with large-scale embeddings trained under ParameterServerStrategy?

Time: 2019-11-08 12:06:03

Tags: tensorflow-estimator

I found that tf.estimator needs to load the entire model from the checkpoint file into the evaluator in order to export a SavedModel. However, when we train a model with large-scale embeddings under ParameterServerStrategy, the model size can exceed the evaluator's memory, so the model cannot be loaded. In that case, how can a SavedModel be exported through the Estimator API?
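For reference, a minimal sketch of the export path in question, using a toy `model_fn` with a small embedding table (all names, shapes, and the input spec here are illustrative assumptions, not the actual large model). The final `export_saved_model()` call is the step that restores the entire checkpoint into one local session, which is where a model larger than a single machine's memory would fail:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def model_fn(features, labels, mode):
    # Toy embedding table standing in for the large-scale embedding.
    table = tf.get_variable("emb", shape=[100, 8])
    emb = tf.nn.embedding_lookup(table, features["ids"])
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"emb": emb})
    loss = tf.reduce_mean(tf.square(emb))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def serving_input_receiver_fn():
    # Assumed serving signature: a batch of embedding ids.
    ids = tf.placeholder(tf.int64, [None], name="ids")
    return tf.estimator.export.ServingInputReceiver({"ids": ids}, {"ids": ids})


def input_fn():
    return tf.data.Dataset.from_tensors({"ids": [1, 2, 3]}).repeat(2)


model_dir = tempfile.mkdtemp()
est = tf.estimator.Estimator(model_fn, model_dir=model_dir)
est.train(input_fn, steps=2)

# This restores *all* variables from the checkpoint into a local session
# (see the quoted source below) before writing the SavedModel.
export_dir = est.export_saved_model(
    os.path.join(model_dir, "export"), serving_input_receiver_fn)
```

With the toy model this succeeds, because everything fits in memory; the question is about the case where the restore in this step cannot.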

https://github.com/tensorflow/estimator/blob/8c573ba86938a394e036a0376ea29e302d9534ad/tensorflow_estimator/python/estimator/estimator.py#L951-L969

   with tf_session.Session(config=self._session_config) as session:

        if estimator_spec.scaffold.local_init_op is not None:
          local_init_op = estimator_spec.scaffold.local_init_op
        else:
          local_init_op = monitored_session.Scaffold.default_local_init_op()

        # This saver will be used both for restoring variables now,
        # and in saving out the metagraph below. This ensures that any
        # Custom Savers stored with the Scaffold are passed through to the
        # SavedModel for restore later.
        graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True)

        if save_variables and not check_variables:
          raise ValueError('If `save_variables` is `True`, `check_variables` '
                           'must not be `False`.')
        if check_variables:
          try:
            graph_saver.restore(session, checkpoint_path)

0 Answers:

No answers yet.