How to freeze a device-specific SavedModel?

Date: 2020-07-09 09:05:20

Tags: tensorflow tensorflow-serving

I need to freeze a SavedModel for serving, but some of my SavedModels are device-specific. How can I work around this?

import os

import tensorflow as tf
from absl import logging
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph

# saved_model_dir and frozen_dir are passed in from FLAGS in the full script
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.tables_initializer())

    # Load the SavedModel into the session's default graph
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
    inference_graph_def = tf.get_default_graph().as_graph_def()

    # Strip the device assignment from every node in the exported GraphDef
    for node in inference_graph_def.node:
        node.device = ''

    frozen_graph_path = os.path.join(frozen_dir, 'frozen_inference_graph.pb')
    # Output nodes live under the 'NmtModel' name scope
    output_keys = ['ToInt64', 'ToInt32', 'while/Exit_5']
    output_node_names = ','.join(["%s/%s" % ('NmtModel', output_key) for output_key in output_keys])
    _ = freeze_graph.freeze_graph(
            input_graph=inference_graph_def,
            input_saver=None,
            input_binary=True,
            input_saved_model_dir=saved_model_dir,
            input_checkpoint=None,
            output_node_names=output_node_names,
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=frozen_graph_path,
            clear_devices=True,
            initializer_nodes='')
    logging.info("export frozen_inference_graph.pb success!!!")

Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

Caused by op u'NmtModel/transpose/Rank', defined at:
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 55, in <module>
    absl_app.run(main)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 300, in run
    _run_main(main, args)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/absl/app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 50, in main
    saved_model2frozen(FLAGS.saved_model_dir, FLAGS.frozen_dir)
  File "/home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py", line 16, in saved_model2frozen
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_dir)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 197, in load
    return loader.load(sess, tags, import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 350, in load
    **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/saved_model/loader_impl.py", line 278, in load_graph
    meta_graph_def, import_scope=import_scope, **saver_kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1696, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/meta_graph.py", line 806, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
    return func(*args, **kwargs)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 442, in import_graph_def
    _ProcessNewOps(graph)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 234, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3440, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 3299, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "/home/yongxian.zyx/alitranx4Corp/.venv/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1770, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Cannot assign a device for operation NmtModel/transpose/Rank: Operation was explicitly assigned to /device:GPU:4 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0, /job:localhost/replica:0/task:0/device:GPU:1, /job:localhost/replica:0/task:0/device:XLA_CPU:0, /job:localhost/replica:0/task:0/device:XLA_GPU:0 ]. Make sure the device specification refers to a valid device.
     [[node NmtModel/transpose/Rank (defined at /home/yongxian.zyx/alitranx4Corp/mtprime/transformer_sync/saved_model2frozen.py:16)  = Rank[T=DT_INT64, _device="/device:GPU:4"](NmtModel/Placeholder)]]

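It seems some of the models were trained on multiple GPUs but exported to SavedModel without clearing the device information, so their nodes are still pinned to devices such as /device:GPU:4 that do not exist on the machine doing the freezing.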
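To confirm this diagnosis before changing anything, one option is to count the device strings actually recorded in the SavedModel. A minimal sketch, assuming `saved_model.pb` has been parsed into a `tensorflow.core.protobuf.saved_model_pb2.SavedModel` message; the helper name `recorded_devices` is hypothetical. It only reads attributes, so it does not import TensorFlow itself.

```python
from collections import Counter


def recorded_devices(saved_model):
    """Count the explicit device strings on nodes of a parsed SavedModel.

    Works on a SavedModel protobuf message (or any object with the same
    attribute layout). Nodes with an empty device string are skipped,
    because they carry no explicit placement.
    """
    counts = Counter()
    for mg in saved_model.meta_graphs:
        # Top-level nodes of each meta graph
        for node in mg.graph_def.node:
            if node.device:
                counts[node.device] += 1
        # Nodes inside function bodies stored in the graph's function library
        for func in mg.graph_def.library.function:
            for node in func.node_def:
                if node.device:
                    counts[node.device] += 1
    return counts


# Usage (requires TensorFlow; saved_model_dir is the question's variable):
# from pathlib import Path
# from tensorflow.core.protobuf.saved_model_pb2 import SavedModel
# sm = SavedModel()
# sm.ParseFromString(Path(saved_model_dir, 'saved_model.pb').read_bytes())
# print(recorded_devices(sm))
```

If the counts show GPU indices that the serving machine does not have, the SavedModel is device-specific in exactly the way the error describes.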

1 Answer:

Answer 0 (score: 2)

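I'm not sure whether there is a better way to handle this, but one option is to simply edit the SavedModel file (saved_model.pb) and remove the device specifications. Back up the SavedModel before doing this, just in case.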
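The backup step can be done with the standard library before `saved_model.pb` is overwritten. A minimal sketch, assuming the SavedModel is a plain local directory; the helper name `backup_saved_model` and the `.bak` suffix are hypothetical choices, not part of any TensorFlow API.

```python
import shutil
from pathlib import Path


def backup_saved_model(model_dir, suffix='.bak'):
    """Copy a SavedModel directory to a sibling '<name><suffix>' directory.

    Refuses to overwrite an existing backup, so a second run cannot
    clobber a known-good copy. Returns the backup path.
    """
    src = Path(model_dir).resolve()
    dst = src.with_name(src.name + suffix)
    if dst.exists():
        raise FileExistsError('backup already exists: %s' % dst)
    # copytree copies saved_model.pb together with variables/ and assets/
    shutil.copytree(str(src), str(dst))
    return dst
```

After calling `backup_saved_model(saved_model_dir)`, the edit can be run on the original directory; if the result is unusable, delete it and rename the `.bak` copy back.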

from pathlib import Path
import tensorflow as tf
from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

# Read the model file
model_path = saved_model_dir
graph_path = Path(model_path, 'saved_model.pb')
sm = SavedModel()
with graph_path.open('rb') as f:
    sm.ParseFromString(f.read())
# Go through graph and functions to remove every device specification
for mg in sm.meta_graphs:
    for node in mg.graph_def.node:
        node.device = ''
    for func in mg.graph_def.library.function:
        for node in func.node_def:
            node.device = ''
# Overwrite saved_model.pb in place
with graph_path.open('wb') as f:
    f.write(sm.SerializeToString())

# Now load model as usual
# ...