我们需要从一个张量采样到另一个包含概率的张量的值。假设我们有两个形状为(?,3)的张量t1,t2,并且想要找到另一个形状为(?,1)的张量t3,其中包含与t2中的概率有关的t1中每一行的样本。
答案 0 :(得分:0)
您可以分两个步骤进行操作。首先,对索引0
,1
和2
进行采样,然后将这些索引替换为张量值。这可以通过tf.random.categorical
完成(有关此功能的更多信息,请参见this question)。注意tf.random.categorical
是在1.13.1版中添加的。
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
tf.random.set_random_seed(0)
values = tf.placeholder(tf.float32, [None, 3])
probabilities = tf.placeholder(tf.float32, [None, 3])
# You can change the number of samples per row (or make it a placeholder)
num_samples = 1
# Use log to get log-probabilities or give logit values (pre-softmax) directly
logits = tf.log(probabilities)
# Sample column indices
col_idx = tf.random.categorical(logits, num_samples, dtype=tf.int32)
# Make row indices
num_rows = tf.shape(values)[0]
row_idx = tf.tile(tf.expand_dims(tf.range(num_rows), 1), (1, num_samples))
# Gather actual values
result = tf.gather_nd(values, tf.stack([row_idx, col_idx], axis=-1))
# Test
print(sess.run(result, feed_dict={
values: [[10., 20., 30.], [40., 50., 60.]],
# Can be actual or "proportional" probabilities not summing 1
probabilities: [[.1, .3, .6], [0., 4., 2.]]
}))
# [[30.]
# [60.]]