为什么在logits上使用tf.multinomial?对数softmax输入的确切输出是什么?

时间:2019-01-21 06:41:34

标签: tensorflow

我有一个查询,我在github上在线阅读代码,遇到了这一行,

pi = tf.squeeze(tf.multinomial(logits,1), axis=1)

此处logits是MLP的log softmax输出。这里到底发生了什么?

以下是该功能的完整代码:

def mlp_categorical_policy(x, a, hidden_sizes, activation, output_activation, action_space):
    act_dim = action_space.n
    logits = mlp(x, list(hidden_sizes)+[act_dim], activation, None)
    logp_all = tf.nn.log_softmax(logits)
    pi = tf.squeeze(tf.multinomial(logits,1), axis=1)
    logp = tf.reduce_sum(tf.one_hot(a, depth=act_dim) * logp_all, axis=1)
    logp_pi = tf.reduce_sum(tf.one_hot(pi, depth=act_dim) * logp_all, axis=1)
    return pi, logp, logp_pi

0 个答案:

没有答案