我编写了代码以提取权重,以将值分配给模型中的某些张量。
from tensorflow.python import pywrap_tensorflow
vars_weights = {}
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
vars_weights[key] = reader.get_tensor(key)
但是,key
就是这样
'm_0/net_1/dense_2/kernel/Adam'
'm_0/net_1/dense/bias/Adam'
'm_1/net_1/dense_2/bias/Adam'
但是,张量名称从
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
与上面的名称不同。
'm_0/net_1/dense_2/kernel/Adam:0'
'm_0/net_1/dense/bias/Adam:0'
'm_1/net_1/dense_2/bias/Adam:0'
它返回张量的真实名称。如何通过在代码中返回张量的真实名称来修复它?