在拥抱面数据集上迭代 DataLoader 时获取批次索引

时间:2021-07-01 11:52:16

标签: nlp pytorch huggingface-transformers pytorch-dataloader huggingface-datasets

下面的代码来自一个 tutorial 的拥抱脸:

from datasets import load_metric

metric= load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

在循环 for batch in eval_dataloader: 中,我如何知道该批次包含数据集中的哪些索引?

DataLoader 是之前使用

创建的
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

请注意,它没有改组标志,因此可以使用批量大小手动计数,但是如何使用改组进行计数?是否可以在创建数据集和数据加载器时使其成为批处理的字段?

0 个答案:

没有答案