在tensorflow中,我想创建一个变量向量进行训练。创建变量向量时,我使用了另一个张量的条件,因此我使用“ tf.scatter_nd_update”更新了原始变量向量。 不幸的是,我发现由“ tf.scatter_nd_update”更新的变量从来都不是变量并且不能被训练。 就像下面的教程代码一样,
a = tf.Variable([1,2,3,4,5,67,7],dtype=tf.float32)
indices = tf.constant([[1],[3]])
updates = tf.constant([22,22],dtype=tf.float32)
a_update = tf.scatter_nd_update(a,indices,updates)
print(a)
print(a_update)
输出为:
<tf.Variable 'Variable_8:0' shape=(7,) dtype=float32_ref>
Tensor("ScatterNdUpdate_9:0", shape=(7,), dtype=float32_ref)
很显然,a_update绝不是变量。如果我将其更改为
a_update_variable = tf.Variable(a_update)
它将引发类似
的错误ValueError: Input 'ref' passed float expected ref type while building
NodeDef 'Variable_9/ScatterNdUpdate_9_Variable_9_0' using
Op<name=ScatterNdUpdate; signature=ref:Ref(T), indices:Tindices, updates:T -
> output_ref:Ref(T); attr=T:type; attr=Tindices:type,allowed=[DT_INT32,
DT_INT64]; attr=use_locking:bool,default=true>
如何解决该问题?我希望我可以在使用“ tf.scatter_nd_update”之后得到可训练的变量。非常感谢!