给定形状(2,3,4)的基础数组X,其可以被解释为每组3个元素的两组,其中每个元素是4维的,我想以下面的方式从该数组X中采样。 从2组中的每一组我想要选择每个由长度为3的二进制数组定义的2个子集,其他子集将被设置为0.因此采样过程由形状数组(2,2,3)定义。此采样的结果应具有形状(2,2,3,4)。
这里是执行我需要的代码,但我想知道是否可以使用numpy索引更有效地重写。
import numpy as np
np.random.seed(3)
sets = np.random.randint(0, 10, [2, 3, 4])
subset_masks = np.random.randint(0, 2, [2, 2, 3])
print('Base set\n', sets, '\n')
print('Subset masks\n', subset_masks, '\n')
result = np.empty([2, 2, 3, 4])
for set_index in range(sets.shape[0]):
for subset_index, subset in enumerate(subset_masks[set_index]):
print('----')
picked_subset = subset.reshape(3, 1) * sets[set_index]
result[set_index][subset_index] = picked_subset
print('Picking subset ', subset, 'from set #', set_index)
print(picked_subset, '\n')
输出
Base set
[[[8 9 3 8]
[8 0 5 3]
[9 9 5 7]]
[[6 0 4 7]
[8 1 6 2]
[2 1 3 5]]]
Subset masks
[[[0 0 1]
[1 0 0]]
[[1 0 1]
[0 1 1]]]
----
Picking subset [0 0 1] from set # 0
[[0 0 0 0]
[0 0 0 0]
[9 9 5 7]]
----
Picking subset [1 0 0] from set # 0
[[8 9 3 8]
[0 0 0 0]
[0 0 0 0]]
----
Picking subset [1 0 1] from set # 1
[[6 0 4 7]
[0 0 0 0]
[2 1 3 5]]
----
Picking subset [0 1 1] from set # 1
[[0 0 0 0]
[8 1 6 2]
[2 1 3 5]]
答案 0 :(得分:1)
通过在最后一个轴上添加4D
的新轴并将subset_masks
添加为第二个轴,将每个轴延伸到sets
。要添加这些新轴,我们可以使用None/np.newaxis
。然后,利用NumPy broadcasting
执行逐元素乘法,就像这样 -
subset_masks[...,None]*sets[:,None]
可能只是为了踢,我们也可以使用np.einsum
-
np.einsum('ijk,ilj->iljk',sets,subset_masks)