给定布尔掩码,对固定数量的索引进行二次采样

时间:2019-08-23 06:48:33

标签: tensorflow

我有一批带有不同数量的1和0的“布尔”掩码。对于该批次中的每个样本,我希望从包含“ 1”的索引中获取固定数量的索引采样。

假设我要抽样6个指标。

  mask = np.array([[1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1],  # 5 ones
                   [0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1],  # 6 ones
                   [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1]]) # 7 ones
  idc = tf.constant(mask)
  idc = tf.where(idc > 0)

  # ...
  # subsampled_idc = tf.some_tf_magic(mask, 6)
  # or
  # subsampled_idc = tf.some_other_tf_magic(idc, 6)
  # ...

  with tf.Session() as sess:
    print(sess.run(subsampled_idc))

idc中的tf.where将包含

[[ 0  0]
 [ 0  5]
 [ 0  7]
 [ 0  9]
 [ 0 10]
 [ 1  2]
 [ 1  3]
 [ 1  5]
 [ 1  7]
 [ 1  9]
 [ 1 10]
 [ 2  0]
 [ 2  1]
 [ 2  3]
 [ 2  5]
 [ 2  7]
 [ 2  9]
 [ 2 10]]

我更需要的是

[[ 0  5  7  9  10  7]  # randomly using one twice because 5 < 6
 [ 2  3  5  7  9  10]  # using all of them because 6 = 6
 [ 0  1  3  5  9  10]] # randomly leaving out one index because 7 > 6

如果批次的“行”包含少于所需样本的“ 1”,则应进行子采样。如果批处理中的“行”包含更多的“ 1”,则应全部使用并添加一些重复项。

不需要排序,实际上改组索引会更好。

0 个答案:

没有答案