我有3个张量流数组(a
,b
,valid_entries
),它们共享前两个维度[T, N, ?]
。这些数组“ valid_entries”之一的形状为[T,N,1]
,具有布尔值。我想对T*M
个2元组索引(M < N
)进行随机抽样,以便对所有这些索引使用valid_entries[t,m] == 1
。
换句话说,对于每个时间步长,我想从a
和b
中随机选择M个有效条目。
我认为在numpy中,可以通过执行以下操作来解决此任务(为简单起见,我们跳过第一维T):
M = 3
N = 5
valid_entries = [[0],[1],[0],[1],[0]]
valid_indices = np.where(a==1)
valid_indices = np.random.select(valid_indices,np.min(len(valid_indices),M))
a_new = a[valid_indices]
b_new = b[valid_indices]
valid_new = valid_entries[valid_indices]
但是,所有这些都需要在Tensorflow中发生。
在此先感谢您的帮助!
答案 0 :(得分:1)
以下是执行此操作的功能:
import tensorflow as tf
def sample_indices(valid, m, seed=None):
valid = tf.convert_to_tensor(valid)
n = tf.size(valid)
# Flatten boolean tensor
valid_flat = tf.reshape(valid, [n])
# Get flat indices where the tensor is true
valid_idx = tf.boolean_mask(tf.range(n), valid_flat)
# Shuffled valid indices
valid_idx_shuffled = tf.random.shuffle(valid_idx, seed=seed)
# Pick sample from shuffled indices
valid_idx_sample = valid_idx_shuffled[:m]
# Unravel indices
return tf.transpose(tf.unravel_index(valid_idx_sample, tf.shape(valid)))
with tf.Graph().as_default(), tf.Session() as sess:
valid = [[ True, True, False, True],
[False, True, True, False],
[False, True, False, False]]
m = 4
print(sess.run(sample_indices(valid, m, seed=0)))
# [[1 1]
# [1 2]
# [0 1]
# [2 1]]
此sample_indices
对于任何形状的布尔张量都是通用的。如果您的情况valid_entries
的形状为(T, N, 1)
,则将获得形状为(M, 3)
的张量作为输出,尽管您可以忽略最后一列,因为它总是为零(或可以改为通过tf.squeeze(valid_entries, axis=2)
。
注意:最后一个tf.transpose
只是输出形状为(sample_size, num_dimensions)
的张量,而不是相反的样子。但是,如果m
很大,并且您不介意维度的顺序,则可以跳过它以节省一些时间和内存,因为tf.transpose
(与NumPy对应)不一样。整个新张量。