Tensorflow tf.get_variable()索引

时间:2017-01-26 19:10:09

标签: python tensorflow neural-network data-science

下午好。 我继续研究tensorflow,现在停留在重用变量“W”的问题上 以下是代码段:http://pastebin.com/VZETt2ud

我想避免硬编码并从恢复的模型中获取值(而不是10 - get_value()等)。 我在这里读过几个帖子,但到处都只需要整个变量。但是,我不明白如何正确获取,例如,从这里获取数字784:

W = tf.Variable(tf.zeros([784, 10]), name = "W")

我试过了:

idx = tf.constant([0])
temp_var = tf.get_variable("W") 
size_1 = tf.gather(temp_var, idx)

这种方法给了我这个错误: “必须完全定义新变量(W)的形状,但不知道。”

(同样,我避免使用硬编码,不能像[编号,数字]那样写出形状)

我改变了变量的范围,认为它与范围有关,添加了这些行:

with tf.variable_scope("my"):

with tf.variable_scope("my"):
tf.get_variable_scope().reuse_variables()

但是犯了这个错误: “ValueError:变量my / W不存在,或者不是用tf.get_variable()创建的。你是不是想在VarScope中设置reuse = None?” 设置reuse = None后,我仍然遇到同样的问题。

你会这么善良,并建议我如何在这段代码中通过索引获取值?

1 个答案:

答案 0 :(得分:0)

您还需要使用get_variable为第一次访问创建变量。对于后者,您需要设置reuse = True。 像下面这样的东西应该有效:

W = tf.get_variable("W", initializer=tf.zeros([784, 10], dtype=YOUR_DTYPE)
...
temp_var = tf.get_variable("W", reuse=True)
tf.gather(...)