Tensorflow:如何在没有np.where的情况下根据条件随机选择元素?

时间:2019-05-02 10:22:37

标签: numpy tensorflow

我有3个张量流数组(abvalid_entries),它们共享前两个维度[T, N, ?]。这些数组“ valid_entries”之一的形状为[T,N,1],具有布尔值。我想对T*M个2元组索引(M < N)进行随机抽样,以便对所有这些索引使用valid_entries[t,m] == 1

换句话说,对于每个时间步长,我想从ab中随机选择M个有效条目。

我认为在numpy中,可以通过执行以下操作来解决此任务(为简单起见,我们跳过第一维T):

M = 3
N = 5
valid_entries = [[0],[1],[0],[1],[0]]
valid_indices = np.where(a==1)
valid_indices = np.random.select(valid_indices,np.min(len(valid_indices),M))
a_new = a[valid_indices]
b_new = b[valid_indices]
valid_new = valid_entries[valid_indices]

但是,所有这些都需要在Tensorflow中发生。

在此先感谢您的帮助!

1 个答案:

答案 0 :(得分:1)

以下是执行此操作的功能:

import tensorflow as tf

def sample_indices(valid, m, seed=None):
    valid = tf.convert_to_tensor(valid)
    n = tf.size(valid)
    # Flatten boolean tensor
    valid_flat = tf.reshape(valid, [n])
    # Get flat indices where the tensor is true
    valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
    # Shuffled valid indices
    valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
    # Pick sample from shuffled indices
    valid_idx_sample = valid_idx_shuffled[:m]
    # Unravel indices
    return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))

with tf.Graph().as_default(), tf.Session() as sess:
    valid = [[ True,  True, False,  True],
             [False,  True,  True, False],
             [False,  True, False, False]]
    m = 4
    print(sess.run(sample_indices(valid, m, seed=0)))
    # [[1 1]
    #  [1 2]
    #  [0 1]
    #  [2 1]]

sample_indices对于任何形状的布尔张量都是通用的。如果您的情况valid_entries的形状为(T, N, 1),则将获得形状为(M, 3)的张量作为输出,尽管您可以忽略最后一列,因为它总是为零(或可以改为通过tf.squeeze(valid_entries, axis=2)

注意:最后一个tf.transpose只是输出形状为(sample_size, num_dimensions)的张量,而不是相反的样子。但是,如果m很大,并且您不介意维度的顺序,则可以跳过它以节省一些时间和内存,因为tf.transpose(与NumPy对应)不一样。整个新张量。