Pytorch-如果正向操作涉及for循环,则进行迷你批处理

时间:2020-05-02 07:15:04

标签: python machine-learning pytorch

我正在尝试一种注意力自动编码器,该编码器可以执行以下操作:

1)每个输入都是数组的元组(例如,长度为10的元组,其中每个数组的大小为90)

2)编码器对大小从90到30的每个数组进行编码

3)编码后的字符串连接在一起,形成大小为10 x 30 = 300的数组)

4)注意层将encoding_combined压缩为30,并查看要注意哪个数组(在10个数组中)

5)解码器将大小为30的数组解码为90的数组

因此,向前,必须将for循环并入元组中每个数组的编码::

 def forward(self, X):
    x = torch.from_numpy(X[0]).view(1,-1,self.in_features)
    encoded_combined = self.encoder(x)
    encoded_combined = torch.squeeze(encoded_combined)
    for arr in X[1:]:
        x = torch.from_numpy(arr).view(1,-1,self.in_features)
        tail = self.encoder(x)
        tail = torch.squeeze(tail)
        encoded_combined = torch.cat((encoded_combined,tail))
    attn_w = F.softmax(self.attn(encoded_combined), dim = 0)
    encoded_combined = encoded_combined.view(-1, 1)
    attn_applied = torch.mm(attn_w_expanded, encoded_combined)
    decoded = self.decoder(attn_applied.view(1,-1,30))
    return decoded, attn_w

上述向前操作一次只能采样1个样本,训练过程非常缓慢。

。是否有任何方法可以实现此类数据集的迷你批处理/摆脱for循环?

当前,我有一个collat​​e_fcn,如下所示,它一次要采样1个样本,因此我不需要禁用数据加载器的自动批处理。

def collate_fcn_one_sample(data):
  data_batch = [batch[0] for batch in data]
  target_batch = [batch[1] for batch in data]
  for group in data_batch:
      for data in group:
          data = torch.from_numpy(data)
  for target in target_batch:
      target = torch.from_numpy(target)
  return group, target

非常感谢您的帮助!

非常感谢。

0 个答案:

没有答案