使用argmax获得的索引的散布更新张量

时间:2018-06-04 08:29:00

标签: python tensorflow reinforcement-learning

我试图用其他值更新张量的最大值,如下所示:

actions = tf.argmax(output, axis=1)
gen_targets = tf.scatter_nd_update(output, actions, q_value)

我在AttributeError: 'Tensor' object has no attribute 'handle'上收到了错误scatter_nd_update

outputactions是占位符,声明为:

output = tf.placeholder('float', shape=[None, num_action])
reward = tf.placeholder('float', shape=[None])

我做错了什么以及实现这一目标的正确方法是什么?

1 个答案:

答案 0 :(得分:2)

您正在尝试更新类型为output的{​​{1}}的值。占位符是不可变对象,您无法更新占位符的值。您尝试更新的张量应该是变量的类型,例如tf.Variable,以便tf.scatter_nd_update()能够更新其值。 解决此问题的一种方法是创建变量,然后使用tf.assign()将占位符的值分配给变量。由于占位符的其中一个维度为tf.placeholder,并且在运行时可能具有任意大小,因此您可能需要将None validate_shape参数设置为tf.assign(),这样占位符的形状不需要与变量的形状匹配。赋值后,False的形状将与通过占位符提供的对象的实际形状相匹配。

var_output