如何从会话中的图形中获取变量?

时间:2016-11-10 21:56:28

标签: tensorflow

我想检查卷积层的过滤器。为此,我试图在会话的最后一步中获取变量。

以下是我的模型的简化版本:

graph = tf.Graph()
with graph.as_default():

    # Placeholders
    ...

    # Variables.
    conv1_w = tf.Variable(..., name='conv1_w')
    ...

    optimizer = ...
    accuracy = ...

with tf.Session(graph=graph) as session:
    ...
    acc, c1 = session.run([accuracy, conv1_w], feed_dict=feed_test)

我收到以下异常

Fetch argument <tensorflow.python.ops.variables.Variable object at 0x137890710> cannot be interpreted as a Tensor. (Tensor Tensor("conv1_w:0", shape=(5, 5, 1, 16), dtype=float32_ref) is not an element of this graph.)

但是,如果我应用op,我可以在没有错误的情况下获取结果张量:

    c1_op = tf.mul(conv1_w,1.0)
    optimizer = ...
    accuracy = ...

Tensorflow不允许获取变量吗?

1 个答案:

答案 0 :(得分:0)

变量名和“name”参数都是conv1_w

,这可能有些令人困惑

除此之外,您的计算图中似乎根本没有使用它。这意味着没有数据被推送或没有被使用。

快速简单的例子:

import tensorflow as tf
sess = tf.InteractiveSession()

conv1 = tf.Variable(1)

sess.run(tf.initialize_all_variables())
sess.run(conv1)

结果是1

作为旁注:最好使用tf.get_variable()而不是tf.Variable()