在这里,我有一个索引矩阵index
,它是数组a
的索引。张量如下。
import tensorflow as tf
import numpy as np
index = tf.constant([
[ 1, 2, 3,-1,-1],
[ 6, 1, 3,-1,-1],
[ 1, 3,-1, 5, 6],
[-1,-1,-1,-1,-1],
[ 6,-1, 9,-1,-1]
])
a = tf.constant([0,0,0,0,
0,0,0,0,
0,0,0,0,
0,0,0,0,], dtype=np.int32)
我想获得一个数组indexed
,该数组指示那些已被index
索引的数组,如下所示。
indexed = [ 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
我知道tf.scatter_nd_update
或tf.scatter_update
可能会有所帮助。但是,我不知道如何处理-1
,它代表无效的索引(仅用于填充长度)。那么,如何获得如上所述的indexed
数组?