graph.get_tensor_by_name和tf.global_variable之间的区别

时间:2018-06-26 06:15:53

标签: tensorflow machine-learning

我可以通过graph.get_tensor_by_name获得张量,但是我无法在tf.global_variable中找到它。 就我而言,我定义了一些tf.Tensor:

output_y = Dense(units=y.shape[1],activation='softmax',kernel_regularizer=regularizers.l2(),bias_regularizer=regularizers.l2(),activity_regularizer=regularizers.l2(),name='output_y_'+str(index))(pretrain_output)
y_tf = tf.placeholder(tf.float32, shape=(None, y.shape[1]),name='y_tf_'+str(index))
loss_tensor = tf.nn.softmax_cross_entropy_with_logits(logits=output_y, labels=y_tf, name='loss_tensor_' + str(index))

我可以如下导出张量形状和名称:

>>output_y
<tf.Tensor 'train_variable/output_y_0/Softmax:0' shape=(?, 4) dtype=float32>
>>y_tf
<tf.Tensor 'train_variable/y_tf_0:0' shape=(?, 4) dtype=float32>
>>loss_tensor
<tf.Tensor 'train_variable/loss_tensor_0/Reshape_2:0' shape=(?,) dtype=float32>

此外,我可以使用tf.get_default_graph.get_tensor_by_name来检索张量:

>>tf.get_default_graph().get_tensor_by_name('train_variable/output_y_0/Softmax:0')
<tf.Tensor 'train_variable/output_y_0/Softmax:0' shape=(?, 4) dtype=float32>
>>tf.get_default_graph().get_tensor_by_name('train_variable/y_tf_0:0')
<tf.Tensor 'train_variable/y_tf_0:0' shape=(?, 4) dtype=float32>
>>tf.get_default_graph().get_tensor_by_name('train_variable/loss_tensor_0/Reshape_2:0')
<tf.Tensor 'train_variable/loss_tensor_0/Reshape_2:0' shape=(?,) dtype=float32>

但是,在tf.global_variables()中找不到这些变量名。看来tf.global_variables()仅包含参数变量,例如kernel / bias。现在我必须记住张量名称才能检索对象输出(在我的情况下为output_y)。有人可以告诉我如何检索张量,例如在具有所有张量的列表中搜索它吗?

1 个答案:

答案 0 :(得分:0)

节点的 read 操作的张量和作为变量的张量之间存在差异。

变量由一个值和几个操作组成:

import tensorflow as tf
a = tf.get_variable('a', tf.float32)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

sess.run(a)  # gives 42.
sess.run(tf.get_default_graph().get_tensor_by_name('a/read:0'))  # gives 42. as well
print(a.op.outputs)  # <tf.Tensor 'a:0' shape=() dtype=float32_ref>]

它的行为类似:

>>> type(a)
<class 'tensorflow.python.ops.variables.Variable'>
>>> type(tf.get_default_graph().get_tensor_by_name('a/read:0'))
<class 'tensorflow.python.framework.ops.Tensor'>

但是它们是不同的。

最简单的方法是返回output_y,以防万一您再次需要它。否则,请遵循: https://stackoverflow.com/a/36893840/7443104