我有一批带有不同数量的1和0的“布尔”掩码。对于该批次中的每个样本,我希望从包含“ 1”的索引中获取固定数量的索引采样。
假设我要抽样6个指标。
mask = np.array([[1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1], # 5 ones
[0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1], # 6 ones
[1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]]) # 7 ones
idc = tf.constant(mask)
idc = tf.where(idc > 0)
# ...
# subsampled_idc = tf.some_tf_magic(mask, 6)
# or
# subsampled_idc = tf.some_other_tf_magic(idc, 6)
# ...
with tf.Session() as sess:
print(sess.run(subsampled_idc))
idc
中的tf.where
将包含
[[ 0 0]
[ 0 5]
[ 0 7]
[ 0 9]
[ 0 10]
[ 1 2]
[ 1 3]
[ 1 5]
[ 1 7]
[ 1 9]
[ 1 10]
[ 2 0]
[ 2 1]
[ 2 3]
[ 2 5]
[ 2 7]
[ 2 9]
[ 2 10]]
我更需要的是
[[ 0 5 7 9 10 7] # randomly using one twice because 5 < 6
[ 2 3 5 7 9 10] # using all of them because 6 = 6
[ 0 1 3 5 9 10]] # randomly leaving out one index because 7 > 6
如果批次的“行”包含少于所需样本的“ 1”,则应进行子采样。如果批处理中的“行”包含更多的“ 1”,则应全部使用并添加一些重复项。
不需要排序,实际上改组索引会更好。