tensorflow:让一个变量指向另一个变量

时间:2018-10-18 17:38:06

标签: python tensorflow

我有设置tensorflow图的现有代码。这段代码具有如下所示的类:

class NodeXYZ:
    # ....
    self.W_xyz = W_var
    self.b_xyz = b_var
    self.params = [W_var, b_var]

    # ...
    output = self.W_xyz @ x + self.b_xyz

class NodeABC:
    # ...
    self.W_abc = W_var
    self.b_abs = b_var
    self.params = [W_var, b_var]

    # ...
    output = self.W_abx @x + self.b_abc

如您所见,每个节点都有自己的参数,名称不同。它们仅公开一个列表,但列表不包含属性,而是tensorflow变量。

我想通过编辑Node的实例来干扰这些变量(按比例缩放)???在使用之前。这段代码应该适用于所有节点,因此我不想对名称进行硬编码(在任何情况下,这都是不好的做法。可以放心地假设会频繁创建新的节点)。

一个丑陋的解决方案是使用内省。我可以列出node的所有属性,检查该属性是否也出现在node.params中,如果确实存在,则可以覆盖它。

如果可能的话,我想避免内省。我可以使用tensorflow进行上述操作吗? 我找到了tf.assign,但是只复制了一次值。我想插入一个全新的变量:

x = tf.Variable([1.])
r = tf.random_normal(shape=(1,))
assignment_op = tf.assign(x, r)

sess = tf.Session()
sess.run(x) # 1
sess.run(assignment_op)
sess.run(x) # -0.410223
sess.run(x) # -0.410223 arrgh

我尝试的第一件事是一种麻木的方式:

x[:] = r

但是会产生一条错误消息Variable' object does not support item assignment。也许我想做的事根本不可能吗?

0 个答案:

没有答案