我正在使用gpt-2简单软件包:https://github.com/minimaxir/gpt-2-simple
我想获得所有可能的下一个标记的概率作为输出。像这样:
[['A',0.25],['B',0.25],['C',0.25],['D',0.25]]
我已经修改了gpt_2_simple python代码,如下所示:
full_output = sample.sample_sequence(
hparams=hparams,
length=min(length, 1023 - (len(context_tokens) if prefix else 0)),
start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
context=context if prefix else None,
batch_size=batch_size,
temperature=temperature, top_k=top_k, top_p=top_p
)
logit_output = full_output[:,0:]
out = sess.run(output, feed_dict={context: batch_size * [context_tokens]})
logit_out = sess.run(logit_output, feed_dict={context: batch_size * [context_tokens]})
我希望将输出令牌链接到它们的温度对数logit值,然后对其进行解码,以像上面的示例那样获得每个令牌的概率。
有人可以帮助我重新格式化此代码,以便我可以访问输出令牌/对数概率组合吗?
答案 0 :(得分:0)
1)获取已解码令牌的列表:
enc = gpt_2_simple.src.encoder.get_encoder(checkpoint_path)
N_token = len(enc.encoder)
tokens_decoded = [enc.decode([token]) for token in range(N_token)]
2)获得概率:
probs = tf.nn.softmax(logits)