我正在尝试在 PyTorch 中实现领域对抗神经网络(Domain-Adversarial Neural Network, DANN)。我创建了如下的数据集和数据加载器:
from torch.utils import data


class MyDataset(data.Dataset):
    """Map-style dataset over the ``train/inputs`` and ``train/targets``
    arrays stored in an HDF5 file.

    Both arrays are read fully into memory (``[()]``) when the dataset is
    constructed; the file handle is kept in ``self.root`` until ``close()``.

    Args:
        root: Path to the HDF5 file.
        transform: Optional callable applied to each sample (for example a
            torchvision transform such as ``ToTensor``).
        sample_shape: Optional shape that each stored sample is reshaped to
            *before* ``transform`` is applied. torchvision's ``ToTensor``
            only accepts 2-D ``(H, W)`` or 3-D ``(H, W, C)`` arrays and raises
            ``ValueError: pic should be 2/3 dimensional`` for flattened 1-D
            samples, so pass the true image shape here when the HDF5 file
            stores flattened images. ``None`` (default) leaves samples as-is.
    """

    def __init__(self, root, transform=None, sample_shape=None):
        # Imported here so the module (and the class) can be imported on
        # machines without h5py until a dataset is actually opened.
        import h5py

        self.root = h5py.File(root, 'r')
        # Indexing (rather than .get()) raises a clear KeyError if a group or
        # dataset is missing, instead of an AttributeError on None.
        train = self.root['train']
        self.labels = train['targets'][()]
        self.data = train['inputs'][()]
        self.transform = transform
        self.sample_shape = sample_shape

    def __getitem__(self, index):
        """Return ``(sample, label)`` for ``index``."""
        datum = self.data[index]
        if self.sample_shape is not None:
            # Restore the image layout before image transforms see it.
            datum = datum.reshape(self.sample_shape)
        if self.transform is not None:
            datum = self.transform(datum)
        return datum, self.labels[index]

    def __len__(self):
        """Number of samples (taken from the label array)."""
        return len(self.labels)

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.root.close()
然后,我使用 GitHub 上现有的代码来训练网络:
# DANN training loop: each step trains the label classifier on a source batch
# and the domain classifier (through the gradient-reversal layer) on one
# source batch and one target batch.
n_epochs = 10
for epoch_idx in range(n_epochs):
    print(f'Epoch {epoch_idx+1:04d} / {n_epochs:04d}', end='\n=================\n')
    # Fresh iterators every epoch. NOTE(review): max_batches is assumed to be
    # <= min(len(dl_source), len(dl_target)); otherwise next() below raises
    # StopIteration part-way through the epoch -- confirm where it is set.
    dl_source_iter = iter(dl_source)
    dl_target_iter = iter(dl_target)
    for batch_idx in range(max_batches):
        optimizer.zero_grad()

        # Training progress p in [0, 1) drives the GRL lambda schedule
        # lambda = 2 / (1 + exp(-10 * p)) - 1, which ramps from 0 towards 1.
        p = float(batch_idx + epoch_idx * max_batches) / (n_epochs * max_batches)
        grl_lambda = 2. / (1. + np.exp(-10 * p)) - 1

        # --- Source domain: class loss + domain loss ---
        # next() pulls one (images, class labels) batch from the loader.
        # If the dataset uses a ToTensor-style transform, each stored sample
        # must already be 2-D/3-D or this call raises
        # "pic should be 2/3 dimensional".
        X_s, y_s = next(dl_source_iter)
        # Source samples get domain label 0. Sized from the actual batch so a
        # short final batch still matches the prediction length.
        y_s_domain = torch.zeros(len(X_s), dtype=torch.long)
        class_pred, domain_pred = model(X_s, grl_lambda)
        loss_s_label = loss_fn_class(class_pred, y_s)
        loss_s_domain = loss_fn_domain(domain_pred, y_s_domain)

        # --- Target domain: domain loss only (target class labels unused) ---
        X_t, _ = next(dl_target_iter)
        # Target samples get domain label 1.
        y_t_domain = torch.ones(len(X_t), dtype=torch.long)
        _, domain_pred = model(X_t, grl_lambda)
        loss_t_domain = loss_fn_domain(domain_pred, y_t_domain)

        loss = loss_t_domain + loss_s_domain + loss_s_label
        loss.backward()
        optimizer.step()

        print(
            f'[{batch_idx+1}/{max_batches}] '
            f'class_loss: {loss_s_label.item():.4f} '
            f's_domain_loss: {loss_s_domain.item():.4f} '
            f't_domain_loss: {loss_t_domain.item():.4f} '
            f'grl_lambda: {grl_lambda:.3f} '
        )
ValueError Traceback (most recent call last)
<ipython-input-210-eb295f5d5052> in <module>()
18 # Train on source domain
19 #taking images and labels from source domain
---> 20 inputs, targets = (dl_source_iter)
21 # generate source domain labels
22 targets = torch.zeros(batch_size, dtype=torch.long)
6 frames
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in to_tensor(pic)
44 if _is_numpy(pic) and not _is_numpy_image(pic):
---> 45 raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
46
47 if isinstance(pic, np.ndarray):
ValueError: pic should be 2/3 dimensional. Got 1 dimensions.(即:图片应为 2 维或 3 维数组,但实际只有 1 维。)
请问应该如何解决这个错误?谢谢!