TensorFlow,Python,共享变量,在顶部

时间:2016-09-15 09:22:33

标签: python tensorflow

我在使用Python API共享变量的TensorFlow中遇到了问题。

我已阅读官方文档(https://www.tensorflow.org/versions/r0.10/how_tos/variable_scope/index.html),但我仍然无法弄清楚发生了什么。

我在下面写了一个最小的工作示例来说明问题。

简而言之,我希望以下代码执行以下操作:

1)在创建会话后立即初始化一个变量“fc1 / w”,

2)创建一个npy数组“x_npy”以输入占位符“x”,

3)运行一个操作“y”,它应该意识到已经创建了变量“fc1 / w”,然后使用该变量值(而不是初始化新的值)来计算其输出。

4)请注意,我在“linear”函数的变量范围中添加了标志“,reuse = True”,但这似乎没有帮助,因为我一直收到错误:

ValueError: Variable fc1/w does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

这非常令人困惑,因为如果我要删除标志“,reuse = True”,那么TensorFlow会告诉我该变量确实存在:

ValueError: Variable fc1/w already exists, disallowed. Did you mean to set reuse=True in VarScope?

5)请注意我正在使用更大的代码库,我真的希望能够使用共享变量功能,而不是在不使用可能解决特定问题的共享变量的情况下提出黑客攻击我在下面写的示例代码,但可能不会很好地概括。

6)最后,请注意,我真的希望将图表的创建与评估分开。特别是,我不想在会话范围中使用“tf.InteractiveSession()”或创建“y”,即下面:“使用tf.Session()作为sess:”。

这是我在Stack Overflow上的第一篇文章,我对TensorFlow很新,所以如果问题不完全清楚,请接受我的道歉。无论如何,我很乐意提供更多细节或进一步澄清任何方面。

提前谢谢。

import tensorflow as tf
import numpy as np


def linear(x_, output_size, non_linearity, name):
    with tf.variable_scope(name, reuse=True):
        input_size = x_.get_shape().as_list()[1]
        # If doesn't exist, initialize "name/w" randomly:
        w = tf.get_variable("w", [input_size, output_size], tf.float32,
                            tf.random_normal_initializer())
        z = tf.matmul(x_, w)
        return non_linearity(z)


def init_w(name, w_initializer):
    with tf.variable_scope(name):
        w = tf.get_variable("w", initializer=w_initializer)
        return tf.initialize_variables([w])


batch_size = 1
fc1_input_size = 7
fc1_output_size = 5

# Initialize with zeros
fc1_w_initializer = tf.zeros([fc1_input_size, fc1_output_size])

#
x = tf.placeholder(tf.float32, [None, fc1_input_size])

#
y = linear(x, fc1_output_size, tf.nn.softmax, "fc1")

with tf.Session() as sess:

    # Initialize "fc1/w" with zeros.
    sess.run(init_w("fc1", fc1_w_initializer))

    # Create npy array to feed into placeholder x
    x_npy = np.arange(batch_size * fc1_input_size, dtype=np.float32).reshape((batch_size, fc1_input_size))

    # Run y, and print result.
    print(sess.run(y, dict_feed={x: x_npy}))

1 个答案:

答案 0 :(得分:0)

似乎对tf.variable_scope()的调用找到了变量scope / w,即使你在空会话中运行它也是如此。我已经清理了你的代码以进行演示。

** (Protocol.UndefinedError) protocol Ecto.Queryable not implemented for %{}