想象一下,我的标签号是100,但是在数据集中,某些数据已损坏,因此我将其一热表示将其设置为0,当我使用“ motor ll_loss”时,我将损坏位置作为值- 1,所以标签数是101,因此出现了错误,请问有什么办法可以使tensorflow正确处理上述损坏数据,而不是清除损坏数据?
logits = F.log_softmax(torch.randn(5, 100), dim=1)
idx_train = torch.as_tensor([1, 2, 3]).long()
idx_train_labels = torch.as_tensor([0, 4, 2]).long()
fail_idx_train_labels = torch.as_tensor([2, 4, 101]).long()
# right
F.nll_loss(logits[idx_train], idx_train_labels)
# RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.
F.nll_loss(logits[idx_train], fail_idx_train_labels)