如何将tf.scatter_nd与多维张量一起使用

时间:2019-07-10 11:25:01

标签: python tensorflow keras

我正在尝试创建一个新的张量(output),并根据updates张量放置另一个张量(idx)的值。 output的形状应为[batch_size, 1, 4, 4](类似于2x2像素和一个通道的图像),并且update的形状应为[batch_size, 3]

我已经阅读了Tensorflow文档(我正在使用gpu版本1.13.1),发现tf.scatter_nd应该可以解决我的问题。问题是我无法使其正常工作,我认为在理解如何安排idx时遇到问题。

让我们考虑batch_size = 2,所以我正在做的是:

updates = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
output_shape = tf.constant([2, 1, 4, 4])
idx = tf.constant([[[1, 0], [1, 1], [1, 0]], [[0, 0], [0, 1], [0, 2]]])  # shape [2, 3, 2]
idx_expanded = tf.expand_dims(idx, 1)  # so I have shape [2, 1, 3, 2]
output = tf.scatter_nd(idx_expanded, updates, output_shape)

我希望它能正常工作,但是没有,它给了我这个错误:

ValueError: The outer 3 dimensions of indices.shape=[2,1,3,2] must match the outer 3 dimensions of updates.shape=[2,3]: Shapes must be equal rank, but are 3 and 2 for 'ScatterNd_7' (op: 'ScatterNd') with input shapes: [2,1,3,2], [2,3], [4]

我不明白为什么期望updates具有3维。我认为idx必须与output_shape有意义(这就是我使用expand_dims的原因),并且也可以使用updates(为三个点指定两个索引),但是很明显我在这里缺少了一些东西。

任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:0)

我一直在使用该函数,但发现了自己的错误。如果有人遇到这个问题,这就是我要解决的问题:

考虑let alertController = UIAlertController(title: "Language".localized(), message: "To changing language you need to restart application, do you want to restart?".localized(), preferredStyle: .alert) let okAction = UIAlertAction(title: "Yes".localized(), style: UIAlertActionStyle.default) { UIAlertAction in NSLog("OK Pressed") exit(0) } batch_size=2点,3张量必须具有形状idx,其中第一维对应于我们从中获取[2, 3, 4]值的批次,第二维必须等于update的第二维(每批次的点数),而第三维则是updates,因为我们需要4索引:[batch_number,channel,row, col]。遵循问题中的示例:

4

这样,就有可能在新的张量中放置特定的数字。