我有一个动态形状的Tensor,X形状= [? ,? ,256,?] 然后我在计算:
argmax = tf.argmax(X, axis=3) # shape [ ?, ?, 256]
然后我想用与X相同的形状计算Y和最大值,所以我试着做以下几点:
Y = tf.scatter_nd(tf.cast(argmax, tf.int32), tf.ones(tf.shape(argmax)), tf.shape(X))
但我收到以下错误:
ValueError:output.shape = [?,?,?,?]的内部-252维度必须与updates.shape = [?,?,256]的内部1维度相匹配:形状必须相等,但是对于'ScatterNd'(op:'ScatterNd'),输入形状为[?,?,256],[?,?,256],[4]为0和1。
答案 0 :(得分:0)
我通过以下方式找到了解决方案:
Y = tf.one_hot(argmax, 2, 1.0, -1.0)