Training a deep adversarial neural network in PyTorch

Date: 2020-07-05 11:35:02

Tags: tensorflow pytorch

I am trying to implement a domain-adversarial neural network in PyTorch. I built the dataset and data loaders as follows:

``import h5py
from torch.utils import data


class MyDataset(data.Dataset):
    def __init__(self, root, transform=None):
        # Open the HDF5 file and load the whole 'train' split into memory
        self.root = h5py.File(root, 'r')
        self.labels = self.root.get('train').get('targets')[()]
        self.data = self.root.get('train').get('inputs')[()]
        self.transform = transform

    def __getitem__(self, index):
        datum = self.data[index]
        # Apply the optional transform to a single sample
        if self.transform is not None:
            datum = self.transform(datum)
        return datum, self.labels[index]

    def __len__(self):
        return len(self.labels)

    def close(self):
        self.root.close()``

Then I trained the network using code available on GitHub:

n_epochs = 10
for epoch_idx in range(n_epochs):
    print(f'Epoch {epoch_idx+1:04d} / {n_epochs:04d}', end='\n=================\n')
    dl_source_iter = iter(dl_source)
    dl_target_iter = iter(dl_target)
    for batch_idx in range(max_batches):
        optimizer.zero_grad()
        #i=0
        #while i<max_batches:
        # Training progress and GRL lambda
        p = float(batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches)
        # Calculating lambda
        grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1
        # Train on source domain
        # Taking images and labels from source domain
        inputs, targets = (dl_source_iter)
        # Generate source domain labels
        targets = torch.zeros(batch_size, dtype=torch.long)
        # Feeding model images and lambda parameter
        # Getting prediction for the class and domain
        class_pred, domain_pred = model(X_s, grl_lambda)
        # Calculating class (labels) loss for the source domain
        loss_s_label = loss_fn_class(class_pred, y_s)
        # Calculating domain loss for the source data
        loss_s_domain = loss_fn_domain(domain_pred, y_s_domain)
        # Train on target domain
        # Ignoring the labels for the target domain
        X_t, _ = next(dl_target_iter)
        # Getting domain labels for target
        y_t_domain = torch.ones(len(X_t), dtype=torch.long)
        # Getting domain predictions for the target data
        _, domain_pred = model(X_t, grl_lambda)
        # Calculating the domain loss for target data
        loss_t_domain = loss_fn_domain(domain_pred, y_t_domain)
        # Calculating total loss
        loss = loss_t_domain + loss_s_domain + loss_s_label
        loss.backward()
        optimizer.step()
        print(f'[{batch_idx+1}/{max_batches}] '
              f'class_loss: {loss_s_label.item():.4f} '
              f's_domain_loss: {loss_s_domain.item():.4f} '
              f't_domain_loss: {loss_t_domain.item():.4f} '
              f'grl_lambda: {grl_lambda:.3f} '
              )
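One thing worth checking in the loop above is the line `inputs, targets = (dl_source_iter)`: the parentheses do not call anything, so Python unpacks the iterator itself and pulls one batch per name, instead of splitting a single batch the way the target side does with `next(dl_target_iter)`. The sketch below illustrates the difference in plain Python; the `batches` list is a hypothetical stand-in for a DataLoader, not the real data.

```python
# Hypothetical stand-in for a DataLoader: each item is one (inputs, targets) batch.
batches = [("inputs_0", "targets_0"), ("inputs_1", "targets_1")]

# Unpacking the iterator itself consumes two whole batches, one per name.
it = iter(batches)
first, second = it
print(first)   # ('inputs_0', 'targets_0')
print(second)  # ('inputs_1', 'targets_1')

# next() returns a single batch, which then unpacks into inputs and targets.
it = iter(batches)
inputs, targets = next(it)
print(inputs, targets)  # inputs_0 targets_0
```

With more than two batches, the first form would raise `ValueError: too many values to unpack (expected 2)`, so `next(dl_source_iter)` is probably what was intended here.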

But I get a strange error:


  ValueError                                Traceback (most recent call last)
 <ipython-input-210-eb295f5d5052> in <module>()
 18         # Train on source domain
 19         #taking images and labels from source domain
---> 20         inputs, targets = (dl_source_iter)
 21         # generate source domain labels
 22         targets = torch.zeros(batch_size, dtype=torch.long)


 6 frames
 /usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in to_tensor(pic)

 44     if _is_numpy(pic) and not _is_numpy_image(pic):
---> 45         raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
 46 
 47     if isinstance(pic, np.ndarray):

ValueError: pic should be 2/3 dimensional. Got 1 dimensions.

Please advise: how can I get rid of this error? Regards
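Going by the traceback, torchvision's `to_tensor` rejects a NumPy array whose `ndim` is not 2 or 3, and here it received a 1-dimensional one. One possible cause (an assumption, since the layout of the HDF5 file is not shown) is that each sample in `inputs` is stored as a flattened vector. The sketch below restores an image shape before the transform runs; `to_image` and `IMAGE_SHAPE = (28, 28, 1)` are hypothetical names and values chosen only for illustration.

```python
import numpy as np

# Hypothetical (height, width, channels); the real shape depends on the dataset.
IMAGE_SHAPE = (28, 28, 1)


def to_image(datum, shape=IMAGE_SHAPE):
    """Return `datum` as an array, reshaping a flattened 1-D sample to `shape`."""
    arr = np.asarray(datum)
    if arr.ndim == 1:
        # reshape raises ValueError if the element count does not match `shape`
        arr = arr.reshape(shape)
    return arr


flat = np.zeros(28 * 28, dtype=np.uint8)  # a flattened sample, ndim == 1
img = to_image(flat)
print(flat.ndim, img.ndim, img.shape)
```

If this is indeed the cause, a helper like this could run ahead of `ToTensor` in the transform pipeline, for example wrapped in `transforms.Lambda(to_image)`. Printing `dataset.data.shape` first would confirm whether the samples really are 1-D.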

0 answers:

No answers