我有设置tensorflow图的现有代码。这段代码具有如下所示的类:
class NodeXYZ:
# ....
self.W_xyz = W_var
self.b_xyz = b_var
self.params = [W_var, b_var]
# ...
output = self.W_xyz @ x + self.b_xyz
class NodeABC:
# ...
self.W_abc = W_var
self.b_abs = b_var
self.params = [W_var, b_var]
# ...
output = self.W_abx @x + self.b_abc
如您所见,每个节点都有自己的参数,名称不同。它们仅公开一个列表,但列表不包含属性,而是tensorflow变量。
我想通过编辑Node的实例来干扰这些变量(按比例缩放)???在使用之前。这段代码应该适用于所有节点,因此我不想对名称进行硬编码(在任何情况下,这都是不好的做法。可以放心地假设会频繁创建新的节点)。
一个丑陋的解决方案是使用内省。我可以列出node
的所有属性,检查该属性是否也出现在node.params
中,如果确实存在,则可以覆盖它。
如果可能的话,我想避免内省。我可以使用tensorflow进行上述操作吗? 我找到了tf.assign,但是只复制了一次值。我想插入一个全新的变量:
x = tf.Variable([1.])
r = tf.random_normal(shape=(1,))
assignment_op = tf.assign(x, r)
sess = tf.Session()
sess.run(x) # 1
sess.run(assignment_op)
sess.run(x) # -0.410223
sess.run(x) # -0.410223 arrgh
我尝试的第一件事是一种麻木的方式:
x[:] = r
但是会产生一条错误消息Variable' object does not support item assignment
。也许我想做的事根本不可能吗?