我正在使用以下方式加载冻结的TensorFlow模型(.pb
)文件,
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile( 'raw_model/model.pb' , 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
sess = tf.Session(graph=detection_graph)
input_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
input_tensor.set_shape(shape=(None, 64, 64, 3))
基本上,我正在尝试更新image_tensor
的形状。更新image_tensor
的形状之后,我想将detection_graph
保存为冻结图(.pb
)。
更新形状后,我像这样写detection_graph
,
tf.train.write_graph( detection_graph , "models" , "model.pb" , as_text=False )
但是,当我解析新创建的冻结图时,看不到image_tensor
的更新形状,
model_path = 'models/model.pb'
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile( model_path, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
sess = tf.Session(graph=detection_graph)
input_tensor = detection_graph.get_tensor_by_name( 'image_tensor:0' )
print( input_tensor.shape )
# The output is -> ( ? , ? , ? , 3 )
形状( ? , ? , ? , 3 )
属于原始冻结图。我需要此形状为( ? , 64 , 64 , 3 )
。
我们如何解析冻结的图(
.pb
文件),更新其中的张量形状,然后再次将其转换为冻结的图?另外,通过这种方式,我可以验证形状是否已更新。