I have:
celery -A tasks worker --loglevel=info -E
This works for the first iteration, but then I get an error on the next iteration:
context = torch.tensor(context, dtype=torch.long, device=self.device)
context = context.unsqueeze(0)
generated = context
with torch.no_grad():
    past_outputs = None
    for i in trange(num_words):
        print(i, num_words)
        inputs = {"input_ids": generated}
        outputs, past_outputs = self.model(
            **inputs,
            past=past_outputs
        )
        next_token_logits = outputs[
            0, -1, :] / (temperature if temperature > 0 else 1.0)
        # repetition penalty from CTRL
        # (https://arxiv.org/abs/1909.05858)
        for _ in set(generated.view(-1).tolist()):
            next_token_logits[_] /= repetition_penalty
        filtered_logits = top_k_top_p_filtering(
            next_token_logits, top_k=top_k, top_p=top_p)
        if temperature == 0:  # greedy sampling
            next_token = torch.argmax(filtered_logits).unsqueeze(0)
        else:
            next_token = torch.multinomial(
                F.softmax(filtered_logits, dim=-1), num_samples=1)
        generated = torch.cat(
            (generated, next_token.unsqueeze(0)), dim=1)
Am I doing something wrong?
Answer 0 (score: 3)
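I believe the problem is that context contains integer values that exceed the vocabulary size. My assumption is based on the last lines of the traceback: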
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
IndexError: index out of range in self
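One way to test this hypothesis is to scan the token IDs before they reach the model. The sketch below is only an illustration: find_out_of_range_ids is a hypothetical helper, not part of any library, and the vocabulary size of 50257 assumes a standard GPT-2 vocabulary, which the question does not confirm.

```python
def find_out_of_range_ids(token_ids, vocab_size):
    """Return (position, token_id) pairs that an embedding table of
    size vocab_size cannot look up (valid IDs are 0 .. vocab_size - 1)."""
    return [
        (pos, tok)
        for pos, tok in enumerate(token_ids)
        if not 0 <= tok < vocab_size
    ]


# Hypothetical input: 50257 is one past the last valid ID, -1 is negative.
bad = find_out_of_range_ids([0, 5, 50257, -1], vocab_size=50257)
print(bad)
```

With the code from the question, something like find_out_of_range_ids(generated.view(-1).tolist(), vocab_size) could be called right before self.model(...). If it returns an empty list, the token IDs are fine and the out-of-range index is coming from somewhere else.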
Answer 1 (score: 0)
I got it working by passing only the newly generated token to the model once past is set:
outputs, past_outputs = self.models[model_name](
    context,
    past=past_outputs
)
context = next_token.unsqueeze(0)
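A possible explanation for why this helps, as a rough sketch rather than a confirmed diagnosis: in GPT-2-style models, the positions of new inputs start at the number of tokens already in the cache. If the whole sequence is fed again together with past, every token is counted twice and the position index can run past the position embedding table (1024 entries for GPT-2). The function below is hypothetical and only does the bookkeeping. It does not call any model, and the 600-token context is made up.

```python
def max_position_per_step(context_len, steps, feed_full_sequence):
    """Highest position index requested at each decoding step.

    Assumes a GPT-2-style cache: positions of the new inputs start at
    the number of tokens already cached.
    """
    cached = 0             # tokens already stored in the cache
    seq_len = context_len  # total tokens generated so far
    highest = []
    for step in range(steps):
        # The first step always feeds the full context; later steps feed
        # either the whole sequence again or just the newest token.
        n_inputs = seq_len if (step == 0 or feed_full_sequence) else 1
        highest.append(cached + n_inputs - 1)
        cached += n_inputs
        seq_len += 1       # one new token is sampled per step
    return highest


full = max_position_per_step(600, 3, feed_full_sequence=True)
last_only = max_position_per_step(600, 3, feed_full_sequence=False)
print(full)       # [599, 1200, 1802]
print(last_only)  # [599, 600, 601]
```

Under these assumptions, feeding the full sequence already asks for position 1200 on the second step. That would match an IndexError that appears only after the first iteration. Feeding just the last token keeps the positions consecutive.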