如何从Tensorflow和yoloV3中的协议缓冲区pbfile恢复训练?

时间:2019-07-19 02:25:39

标签: python tensorflow protocol-buffers resuming-training

我可以保存ckpt文件,将图形冻结在pb文件中并用它在图像上进行测试。但是现在,我想从cb文件中恢复训练,对于C ++。所以我有两个子问题,

  1. 如何获取用于训练的所有op_node_names。
  2. 如何从pbfile恢复训练。

使用C ++中的pb文件,我可以保存该pb文件以进行测试。

我所有的代码都在这里:https://github.com/YunYang1994/tensorflow-yolov3

使用 convert_weight.py 中的代码来获取用于保存pb文件的节点名称。

但是它会显示错误提示,例如“ *****不在图中”

for var in tf.global_variables():
var_name = var.op.name
var_name_mess = str(var_name).split('/')
var_shape = var.shape
print("111111111111111111112222222222222222222=> ")
print(var_name_mess[0])
if flag.train_from_coco:
    if var_name_mess[0] in preserve_cur_names: continue
cur_weights_mess.append([var_name, var_shape])
org_weights_num = len(org_weights_mess)
cur_weights_num = len(cur_weights_mess)
if cur_weights_num != org_weights_num:
raise RuntimeError

print('=> Number of weights that will rename:\t%d' % cur_weights_num)
cur_to_org_dict = {}
for index in range(org_weights_num):
org_name, org_shape = org_weights_mess[index]
cur_name, cur_shape = cur_weights_mess[index]
if cur_shape != org_shape:
    print(org_weights_mess[index])
    print(cur_weights_mess[index])
    raise RuntimeError
cur_to_org_dict[cur_name] = org_name
print("3333=> " + str(cur_name).ljust(50) + ' : ' + org_name)

with tf.name_scope('load_save'):
    name_to_var_dict = {var.op.name: var for var in         
tf.global_variables()}
restore_dict = {cur_to_org_dict[cur_name]: name_to_var_dict[cur_name] for         
cur_name in cur_to_org_dict}
load = tf.train.Saver(restore_dict)
save = tf.train.Saver(tf.global_variables())
for var in tf.global_variables():
    print("44444=> " + var.op.name)

您能帮助我在python代码中筛选有用的节点名称以及如何恢复训练。

0 个答案:

没有答案