阐明方法:def get_batch(self,er_vocab,er_vocab_pairs,idx):

时间:2020-04-23 15:24:20

标签: pytorch

我在看codes online,但想知道以下代码是做什么的?

def get_batch(self, er_vocab, er_vocab_pairs, idx):
    batch = er_vocab_pairs[idx:idx+self.batch_size]
    targets = np.zeros((len(batch), len(d.entities)))
    for idx, pair in enumerate(batch):
        targets[idx, er_vocab[pair]] = 1.
    targets = torch.FloatTensor(targets)
    if self.cuda:
        targets = targets.cuda()
    return np.array(batch), targets

因此,据我了解,batch = er_vocab_pairs[idx:idx+self.batch_size]从er_vocab_pairs中获得了一批元素的子集。目标表列出了某种指标。但是我们应该理解其余的代码。这不容易理解。

0 个答案:

没有答案