I have a question. While reading some code on GitHub, I came across this line:
pi = tf.squeeze(tf.multinomial(logits,1), axis=1)
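Here, logits is the log-softmax output of an MLP. What exactly is happening in this line?
Here is the full code of the function that the line comes from: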
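One way to get a feel for the line is to mimic it in NumPy. tf.multinomial(logits, 1) treats each row of logits as unnormalized log-probabilities and draws one class index per row, returning shape [batch_size, 1]; tf.squeeze(..., axis=1) then drops that length-1 axis so the result has shape [batch_size]. Because softmax is unchanged by adding a constant to a row, raw network outputs and log-softmax values describe the same distribution. The sketch below is only an illustration: sample_categorical is a hypothetical stand-in (not TensorFlow's implementation), and it uses the Gumbel-max trick as one possible way to sample.

```python
import numpy as np

def sample_categorical(logits, num_samples, rng):
    """Hypothetical NumPy stand-in for tf.multinomial(logits, num_samples).

    logits: array of shape (batch, num_classes) holding unnormalized
    log-probabilities. Returns integer class indices of shape
    (batch, num_samples).
    """
    batch, num_classes = logits.shape
    # Gumbel-max trick: argmax(logits + Gumbel noise) is a draw from
    # the categorical distribution softmax(logits).
    noise = rng.gumbel(size=(batch, num_samples, num_classes))
    return np.argmax(logits[:, None, :] + noise, axis=-1)

rng = np.random.default_rng(0)

# Two states, three actions. The logits are deliberately extreme so that
# each row puts essentially all of its probability on a single action.
logits = np.array([[0.0, 0.0, 50.0],
                   [50.0, 0.0, 0.0]])

samples = sample_categorical(logits, 1, rng)  # shape (2, 1): one draw per row
pi = np.squeeze(samples, axis=1)              # shape (2,): one action per state

print(samples.shape, pi.shape, pi)
```

With less extreme logits, pi would vary from call to call, which is the point: the policy is stochastic, and pi is a sampled action for each input state.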
def mlp_categorical_policy(x, a, hidden_sizes, activation, output_activation, action_space):
    act_dim = action_space.n
    logits = mlp(x, list(hidden_sizes)+[act_dim], activation, None)
    logp_all = tf.nn.log_softmax(logits)
    pi = tf.squeeze(tf.multinomial(logits,1), axis=1)
    logp = tf.reduce_sum(tf.one_hot(a, depth=act_dim) * logp_all, axis=1)
    logp_pi = tf.reduce_sum(tf.one_hot(pi, depth=act_dim) * logp_all, axis=1)
    return pi, logp, logp_pi
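
The logp and logp_pi lines can be read the same way. Multiplying logp_all by a one-hot mask and summing over axis 1 picks out, for each row, the log-probability of one chosen action. Below is a minimal NumPy sketch of that step, assuming made-up logits and actions; log_softmax here is a hand-written helper for illustration, not the TensorFlow function.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log-softmax, shifted by the row maximum for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# Made-up example: two states, three actions.
logits = np.array([[1.0, 2.0, 3.0],
                   [0.0, 0.0, 0.0]])
actions = np.array([2, 0])  # plays the role of `a` (or `pi`) in the function
act_dim = logits.shape[1]

logp_all = log_softmax(logits)      # shape (2, 3): log-probability of every action
one_hot = np.eye(act_dim)[actions]  # shape (2, 3): 1.0 at the chosen action, else 0.0

# The mask zeroes out every entry except the chosen action's, so the sum
# over axis 1 leaves exactly one log-probability per row.
logp = np.sum(one_hot * logp_all, axis=1)

# Same values obtained by direct indexing, for comparison.
logp_indexed = logp_all[np.arange(len(actions)), actions]

print(logp, logp_indexed)
```

So logp is the log-probability of the action a that was passed in, while logp_pi is the log-probability of the freshly sampled action pi.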