按索引过滤并以numpy展平,例如tf.sequence_mask

时间:2019-01-15 12:53:21

标签: python numpy tensorflow masking

我想用索引过滤2D数组,然后仅使用过滤器中的值平整该数组。这几乎是tf.sequence_mask会做的,但我需要在numpy或其他灯光库中使用。

谢谢!

PD: 这是一个示例:

array_2d = [[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]] # this is a numpy array
array_len = [6,5,3]
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]

2 个答案:

答案 0 :(得分:1)

这是使用布尔蒙版并将其应用于展平的array_2d

的一种方法
array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

# Create a boolean mask
mask = np.zeros((array_2d.shape), dtype=bool)

# Change to True for elements to be kept
for i, j in enumerate(array_len):
        mask[i][0:j] = True

expected_output = array_2d.flatten()[mask.flatten()]

输出

array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])

答案 1 :(得分:1)

这是一个vectorized解决方案,使用布尔掩码为array_2d编制索引:

array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]]) 
array_len = [6,5,3]

m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
array([ 0,  1,  2,  3,  4,  5,  8,  9, 10, 11, 12, 21, 22, 21])

详细信息

创建遮罩时,将cumsum放在形状与array_2d相同的ones的ndarray上,并进行行比较以查看哪些元素大于{{ 1}}。

所以第一步是创建以下array_len

ndarray

并使用np.ones(array_2d.shape).cumsum(axis=1) array([[1., 2., 3., 4., 5., 6.], [1., 2., 3., 4., 5., 6.], [1., 2., 3., 4., 5., 6.]]) 进行行比较:

array_len

然后,您只需使用以下方法过滤数组:

~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True, False],
       [ True,  True,  True, False, False, False]])