检查点中的张量名称与tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)不同

时间:2019-03-26 05:48:56

标签: tensorflow

我编写了代码以提取权重,以将值分配给模型中的某些张量。

from tensorflow.python import pywrap_tensorflow

vars_weights = {}
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in sorted(var_to_shape_map):
    vars_weights[key] = reader.get_tensor(key)

但是,key就是这样

'm_0/net_1/dense_2/kernel/Adam'
'm_0/net_1/dense/bias/Adam'
'm_1/net_1/dense_2/bias/Adam'

但是,张量名称从  tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)与上面的名称不同。

'm_0/net_1/dense_2/kernel/Adam:0'
'm_0/net_1/dense/bias/Adam:0'
'm_1/net_1/dense_2/bias/Adam:0'

它返回张量的真实名称。如何通过在代码中返回张量的真实名称来修复它?

0 个答案:

没有答案