我在看codes online,但想知道以下代码是做什么的?
def get_batch(self, er_vocab, er_vocab_pairs, idx):
batch = er_vocab_pairs[idx:idx+self.batch_size]
targets = np.zeros((len(batch), len(d.entities)))
for idx, pair in enumerate(batch):
targets[idx, er_vocab[pair]] = 1.
targets = torch.FloatTensor(targets)
if self.cuda:
targets = targets.cuda()
return np.array(batch), targets
因此,据我了解,batch = er_vocab_pairs[idx:idx+self.batch_size]
从er_vocab_pairs中获得了一批元素的子集。目标表列出了某种指标。但是我们应该理解其余的代码。这不容易理解。