如何使用tf.argmax索引正确使用tf.scatter_nd?

时间:2017-10-05 08:15:30

标签: tensorflow

我有一个动态形状的Tensor,X形状= [? ,? ,256,?] 然后我在计算:

argmax = tf.argmax(X, axis=3) # shape [ ?, ?, 256]

然后我想用与X相同的形状计算Y和最大值,所以我试着做以下几点:

Y = tf.scatter_nd(tf.cast(argmax, tf.int32), tf.ones(tf.shape(argmax)), tf.shape(X))

但我收到以下错误:

  

ValueError:output.shape = [?,?,?,?]的内部-252维度必须与updates.shape = [?,?,256]的内部1维度相匹配:形状必须相等,但是对于'ScatterNd'(op:'ScatterNd'),输入形状为[?,?,256],[?,?,256],[4]为0和1。

1 个答案:

答案 0 :(得分:0)

我通过以下方式找到了解决方案:

Y = tf.one_hot(argmax, 2, 1.0, -1.0)