我有一个数据加载器,如下所示,它以批处理大小加载数据(默认为8):
class DataLoaderStego(DataLoader):
def __init__(self, cover_dir, stego_dir, embedding_otf=False,
shuffle=False, pair_constraint=False, batch_size=1,
transform=None, num_workers=0, pin_memory=False):
self.pair_constraint = pair_constraint
self.embedding_otf = embedding_otf
if pair_constraint and batch_size % 2 == 0:
dataset = DatasetPair(cover_dir, stego_dir, embedding_otf,
transform)
_batch_size = int(batch_size / 2)
else:
dataset = DatasetNoPair(cover_dir, stego_dir, embedding_otf,
transform)
_batch_size = batch_size
if pair_constraint:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
else:
sampler = RandomBalancedSampler(dataset)
super(DataLoaderStego, self). \
__init__(dataset, _batch_size, None, sampler,
None, num_workers, pin_memory=pin_memory, drop_last=True)
self.shuffle = shuffle
def __iter__(self):
return DataLoaderIterWithReshape(self)
# if self.pair_constraint:
# return DataLoaderIterWithReshape(self)
# else:
# return DataLoaderIter(self)
此类使用DatasetPair
类创建数据集:
class DatasetPair(Dataset):
def __init__(self, cover_dir, stego_dir, embedding_otf=False,
transform=None):
self.cover_dir = cover_dir
self.stego_dir = stego_dir
self.cover_list = [x.split('/')[-1]
for x in glob(cover_dir + '/*')]
self.transform = transform
self.embedding_otf = embedding_otf
assert len(self.cover_list) != 0, "cover_dir is empty"
# stego_list = ['.'.join(x.split('/')[-1].split('.')[:-1])
# for x in glob(stego_dir + '/*')]
def __getitem__(self, idx):
idx = int(idx)
labels = np.array([0, 1], dtype='int32')
cover_path = os.path.join(self.cover_dir,
self.cover_list[idx])
print("cover path", cover_path)
# cover = Image.open(cover_path)
cover = Image.open(cover_path)
# print("@@@", cover)
images = np.empty((2, cover.size[0], cover.size[1], 1),
dtype='uint8')
print("images", images.shape) # Print image
images[0, :, :, 0] = np.array(cover)
if self.embedding_otf:
images[1, :, :, 0] = np.copy(images[0, :, :, 0])
beta_path = os.path.join(self.stego_dir,
'.'.join(self.cover_list[idx].
split('.')[:-1]) + '.mat')
beta_map = io.loadmat(beta_path)['pChange']
rand_arr = np.random.rand(cover.size[0], cover.size[1])
inf_map = rand_arr < (beta_map / 2.)
images[1, np.logical_and(
images[0, :, :, 0] != 255, inf_map), 0] += 1
inf_map[:, :] = rand_arr > 1 - (beta_map / 2.)
images[1, np.logical_and(
images[0, :, :, 0] != 0, inf_map), 0] -= 1
else:
stego_path = os.path.join(self.stego_dir,
self.cover_list[idx])
print("stego path", stego_path)
images[1, :, :, 0] = Image.open(stego_path)
samples = {'images': images, 'labels': labels}
print(images[0].shape)
if self.transform:
samples = self.transform(samples)
return samples
我打印cover
时,所有图像的尺寸均为256 * 256:
@@@ <PIL.PpmImagePlugin.PpmImageFile image mode=L size=256x256 at 0x7F96631B8748>
当我使用以下配置创建网络模型时:
class MyNet(nn.Module):
def __init__(self, with_bn=False, threshold=3):
super(YeNet, self).__init__()
self.with_bn = with_bn
self.preprocessing = SRM_conv2d(1, 0)
self.TLU = nn.Hardtanh(-threshold, threshold, True)
...
def forward(self, x):
x = x.float()
print("########", x.shape)
x = self.preprocessing(x)
...
运行我的网络,出现此错误:
Traceback (most recent call last):
File "main.py", line 169, in <module>
train(epoch)
File "main.py", line 121, in train
outputs = net(images)
File "/home/emad/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/emad/myworks/stego/codes/analysis/YeNet-Pytorch/YeNet.py", line 110, in forward
x = self.preprocessing(x)
File "/home/emad/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/emad/myworks/stego/codes/analysis/YeNet-Pytorch/YeNet.py", line 53, in forward
self.dilation, self.groups
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 30 1 5 5 0, but got 5-dimensional input of size [4, 2, 1, 256, 256] instead
这是我的网络通话:
for batch_idx, data in enumerate(train_loader):
images, labels = Variable(
data['images']), Variable(data['labels'])
if args.cuda:
images, labels = images.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = net(images)
这是self.preprocessing
函数:
class SRM_conv2d(nn.Module):
def __init__(self, stride=1, padding=0):
super(SRM_conv2d, self).__init__()
self.in_channels = 1
self.out_channels = 30
self.kernel_size = (5, 5)
if isinstance(stride, int):
self.stride = (stride, stride)
else:
self.stride = stride
if isinstance(padding, int):
self.padding = (padding, padding)
else:
self.padding = padding
self.dilation = (1, 1)
self.transpose = False
self.output_padding = (0,)
self.groups = 1
print(1)
self.weight = Parameter(torch.Tensor(30, 1, 5, 5),
requires_grad=True)
print(2)
self.bias = Parameter(torch.Tensor(30),
requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.numpy()[:] = SRM_npy
self.bias.data.zero_()
def forward(self, input):
# print(
# "3333", input, self.weight, self.bias,
# self.stride, self.padding,
# self.dilation, self.groups
# )
print("###")
return F.conv2d(
input, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups
)
print("$$$")
我想为我的网络启用批量输入。而且,当我使用batch_size=1
运行代码时,也会遇到相同的错误。
其他信息: 这是我使用的三个变换:
class ToTensor(object):
def __call__(self, samples):
images, labels = samples['images'], samples['labels']
print("EEE", images.shape)
images = images.transpose((0, 3, 1, 2))
# images = (images.transpose((0,3,1,2)).astype('float32') / 127.5) - 1.
return {'images': torch.from_numpy(images),
'labels': torch.from_numpy(labels).long()}
class RandomRot(object):
def __call__(self, samples):
images = samples['images']
rot = random.randint(0, 3)
return {'images': np.rot90(images, rot, axes=[1, 2]).copy(),
'labels': samples['labels']}
class RandomFlip(object):
def __call__(self, samples):
if random.random() < 0.5:
images = samples['images']
return {'images': np.flip(images, axis=2).copy(),
'labels': samples['labels']}
else:
return samples