为什么“ tf.scatter_nd_update”会更改数据类型?

时间:2019-03-18 22:43:04

标签: python tensorflow

在tensorflow中,我想创建一个变量向量进行训练。创建变量向量时,我使用了另一个张量的条件,因此我使用“ tf.scatter_nd_update”更新了原始变量向量。 不幸的是,我发现由“ tf.scatter_nd_update”更新的变量从来都不是变量并且不能被训练。 就像下面的教程代码一样,

a = tf.Variable([1,2,3,4,5,67,7],dtype=tf.float32)
indices = tf.constant([[1],[3]])
updates = tf.constant([22,22],dtype=tf.float32)
a_update = tf.scatter_nd_update(a,indices,updates)
print(a)
print(a_update)

输出为:

<tf.Variable 'Variable_8:0' shape=(7,) dtype=float32_ref>
Tensor("ScatterNdUpdate_9:0", shape=(7,), dtype=float32_ref)

很显然,a_update绝不是变量。如果我将其更改为

a_update_variable = tf.Variable(a_update)

它将引发类似

的错误
ValueError: Input 'ref' passed float expected ref type while building 
NodeDef 'Variable_9/ScatterNdUpdate_9_Variable_9_0' using 
Op<name=ScatterNdUpdate; signature=ref:Ref(T), indices:Tindices, updates:T - 
> output_ref:Ref(T); attr=T:type; attr=Tindices:type,allowed=[DT_INT32, 
DT_INT64]; attr=use_locking:bool,default=true>

如何解决该问题?我希望我可以在使用“ tf.scatter_nd_update”之后得到可训练的变量。非常感谢!

0 个答案:

没有答案