How to convert a protobuf graph to binary wire format?

Posted: 2016-07-24 05:51:36

Tags: python python-2.7 tensorflow protocol-buffers

I have a method that converts the binary wire format to a human-readable format, but I can't do the reverse:
import tensorflow as tf
from tensorflow.python.platform import gfile

def converter(filename): 
  with gfile.FastGFile(filename,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pb', as_text=True)
  return

I just pass in the file name for this. But when I try to do the reverse, I get:

  File "pb_to_pbtxt.py", line 16, in <module>
    converter('protobuf.pb')  # here you can write the name of the file to be converted
  File "pb_to_pbtxt.py", line 11, in converter
    graph_def.ParseFromString(f.read())
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/message.py", line 185, in ParseFromString
    self.MergeFromString(serialized)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1008, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/python_message.py", line 1034, in InternalParse
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 868, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/usr/local/lib/python2.7/dist-packages/google/protobuf/internal/decoder.py", line 838, in _RaiseInvalidWireType
    raise _DecodeError('Tag had invalid wire type.')
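This is the error ParseFromString() reports when it is handed a text-format file: it decodes the bytes as binary tags, and a GraphDef pbtxt typically begins with `node {`. In the wire format each tag is the varint `(field_number << 3) | wire_type`, so the low three bits of the first byte are the wire type; for `n` (0x6E) that is 6, which is not a valid wire type. A minimal sketch of that check, using a hypothetical helper `first_wire_type` and made-up sample bytes:

```python
# Wire types defined by the protobuf encoding (3 and 4 are the deprecated group markers).
VALID_WIRE_TYPES = {0, 1, 2, 3, 4, 5}


def first_wire_type(data: bytes) -> int:
    """Hypothetical helper: wire type of the first tag in `data`.

    A tag is the varint (field_number << 3) | wire_type, so the low three
    bits of the first byte are the wire type even for multi-byte varints.
    """
    return data[0] & 0x07


pbtxt = b'node {\n  name: "x"\n  op: "Const"\n}\n'  # text format, not wire format
binary = b'\n\x01x'  # field 1, wire type 2 (length-delimited), 1 byte: "x"

for label, data in (("pbtxt", pbtxt), ("binary", binary)):
    wt = first_wire_type(data)
    print(label, wt, wt in VALID_WIRE_TYPES)
# pbtxt 6 False
# binary 2 True
```

So the traceback means the input file is already text, and it has to be read with text_format rather than ParseFromString().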

2 answers:

Answer 0 (score: 5)

You can use the google.protobuf.text_format module to do the reverse translation:

import tensorflow as tf
from google.protobuf import text_format

def convert_pbtxt_to_graphdef(filename):
  """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.

  Args:
    filename: The name of a file containing a GraphDef pbtxt (text-formatted
      `tf.GraphDef` protocol buffer data).

  Returns:
    A `tf.GraphDef` protocol buffer.
  """
  with tf.gfile.FastGFile(filename, 'r') as f:
    graph_def = tf.GraphDef()

    file_content = f.read()

    # Merges the human-readable string in `file_content` into `graph_def`.
    text_format.Merge(file_content, graph_def)
  return graph_def
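Combining this with SerializeToString() gives the text-to-binary direction the question asks for. A minimal sketch, assuming only the protobuf package is installed: the helpers `pbtxt_to_wire` and `wire_to_pbtxt` are hypothetical, and `descriptor_pb2.FileDescriptorProto` (which ships with protobuf) stands in for `tf.GraphDef`. With TensorFlow installed you would pass `tf.GraphDef` as `message_cls` instead.

```python
from google.protobuf import descriptor_pb2, text_format


def pbtxt_to_wire(text, message_cls):
    """Hypothetical helper: parse text-format protobuf, return wire-format bytes."""
    message = message_cls()
    # Fill `message` from the human-readable text representation.
    text_format.Parse(text, message)
    return message.SerializeToString()


def wire_to_pbtxt(data, message_cls):
    """Hypothetical helper for the inverse: wire-format bytes back to text."""
    message = message_cls()
    message.ParseFromString(data)
    return text_format.MessageToString(message)


# FileDescriptorProto is only a stand-in for tf.GraphDef here.
text = 'name: "example.proto" package: "demo"'
wire = pbtxt_to_wire(text, descriptor_pb2.FileDescriptorProto)
print(wire)
print(wire_to_pbtxt(wire, descriptor_pb2.FileDescriptorProto))
```

text_format.Parse differs from text_format.Merge (used above) only in that it rejects a singular field appearing more than once in the text; for a well-formed pbtxt either works.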

Answer 1 (score: 2)

You can call tf.Graph.as_graph_def() and then Protobuf's SerializeToString(), like this:

import tensorflow as tf

# GraphDef proto of the current default graph (TF 1.x API).
proto_graph = tf.get_default_graph().as_graph_def()

with open("my_graph.bin", "wb") as f:
    f.write(proto_graph.SerializeToString())

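If you just want to write the file and don't care about the encoding, you can also use tf.train.write_graph(). Its as_text argument defaults to True, so the call below writes the text format; pass as_text=False to get the binary wire format: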

import tensorflow as tf

v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
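The as_text flag of tf.train.write_graph() roughly chooses between the two encodings. A sketch of that switch with plain protobuf, assuming a hypothetical `write_message` helper (not TensorFlow's actual implementation) and again using `descriptor_pb2.FileDescriptorProto` as a stand-in for a GraphDef:

```python
import os
import tempfile

from google.protobuf import descriptor_pb2, text_format


def write_message(message, logdir, name, as_text=True):
    """Hypothetical helper mirroring the as_text switch of tf.train.write_graph."""
    os.makedirs(logdir, exist_ok=True)
    path = os.path.join(logdir, name)
    if as_text:
        # Human-readable text format.
        with open(path, "w") as f:
            f.write(text_format.MessageToString(message))
    else:
        # Binary wire format, the one ParseFromString() expects.
        with open(path, "wb") as f:
            f.write(message.SerializeToString())
    return path


msg = descriptor_pb2.FileDescriptorProto(name="example.proto", package="demo")
outdir = tempfile.mkdtemp()
txt_path = write_message(msg, outdir, "graph.pbtxt", as_text=True)
bin_path = write_message(msg, outdir, "graph.pb", as_text=False)

# Read each file back with the matching parser.
from_bin = descriptor_pb2.FileDescriptorProto()
with open(bin_path, "rb") as f:
    from_bin.ParseFromString(f.read())

from_txt = descriptor_pb2.FileDescriptorProto()
with open(txt_path) as f:
    text_format.Parse(f.read(), from_txt)

print(from_bin == from_txt)
```

Keeping the file extension in line with the encoding (.pbtxt for text, .pb for binary) avoids feeding a text file to ParseFromString() by mistake, which is what produced the traceback in the question.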

Note: tested with TF 0.10; not sure about earlier versions.