谁能举一个小小的例子来解释tf.random.categorical的参数?

时间:2019-03-08 12:19:47

标签: python tensorflow

tensorflow的站点提供了此示例

tf.random.categorical(tf.log([[10., 10.]]), 5)

生成一个张量“具有[1,5]形状,其中每个值是0或1的概率相等”

我已经知道tf.log([[10., 10.]])的基本含义,demo

我想知道[batch_size,num_classes]是做什么的,有人可以举一个小小的例子来解释这些参数吗?

2 个答案:

答案 0 :(得分:1)

您注意到,tf.random.categorical具有两个参数:

  • logits,形状为[batch_size, num_classes]的2D浮点张量。
  • num_samples和整数标量。

输出是形状为[batch_size, num_samples]的2D整数张量。

logits张量(logits[0, :]logits[1, :],...)的每个“行”代表不同categorical distribution的事件概率。但是,该函数并不期望实际的概率值,而是期望的非标准化对数概率;因此实际的机率将是softmax(logits[0, :])softmax(logits[1, :])等。这样做的好处是,您可以基本上给出任何实际值作为输入(例如神经网络的输出),并且它们将是有效的。同样,使用对数使用特定的概率值或比例也很简单。例如,[log(0.1), log(0.3), log(0.6)][log(1), log(3), log(6)]代表相同的概率,其中第二类的概率是第一个类的三倍,但只有第三类的一半。

对于(未归一化的对数)概率的每一行,您从分布中获得num_samples个样本。每个样本都是0num_classes - 1之间的整数,根据给定的概率绘制。因此,结果是形状为[batch_size, num_samples]的2D张量,其中每个分布都有采样的整数。

编辑:该函数的一个小例子。

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    tf.random.set_random_seed(123)
    logits = tf.log([[1., 1., 1., 1.],
                     [0., 1., 2., 3.]])
    num_samples = 30
    cat = tf.random.categorical(logits, num_samples)
    print(sess.run(cat))
    # [[3 3 1 1 0 3 3 0 2 3 1 3 3 3 1 1 0 2 2 0 3 1 3 0 1 1 0 1 3 3]
    #  [2 2 3 3 2 3 3 3 2 2 3 3 2 2 2 1 3 3 3 2 3 2 2 1 3 3 3 3 3 2]]

在这种情况下,结果是一个包含两行30列的数组。第一行中的值是从分类分布中抽样的,其中每个类别([0, 1, 2, 3])的概率相同。在第二行中,类别3是最可能的类别,类别0几乎没有被采样的可能性。

答案 1 :(得分:0)

希望这个简单的例子会有所帮助。

tf.random.categorical 需要两个参数:

  • logits,形状为 [batch_size, num_classes]
  • num_samples

例如:

list_indices.shape = (4, 10)

sampled_indices = tf.random.categorical(list_indices, num_samples=1)

sample_indices 将是

tf.Tensor(
[[2]
 [9]
 [4]
 [7]], shape=(4, 1), dtype=int64)

这意味着从 1 num_samples 中为每行 10 行 (num_classes) 取 4 batch_size