I have some code written in PyTorch that I need to rewrite in Keras. Could you please tell me how to convert it from PyTorch to Keras? For example, how do I define a tensor in Keras instead of using torch.tensor in PyTorch?
self.dct_conv_weights = torch.tensor(gen_filters(8, 8, dct_coeff), dtype=torch.float32).to(self.device)
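As a starting point, here is a minimal sketch of the TensorFlow/Keras counterpart of that line, assuming the same gen_filters and dct_coeff helpers from the code below. In TensorFlow a constant tensor is created with tf.constant, device placement happens automatically (or under a tf.device scope) rather than via .to(device), and the in-place unsqueeze_(1) becomes tf.expand_dims:

import tensorflow as tf

# tf.constant is the TF/Keras counterpart of torch.tensor for fixed weights;
# gen_filters and dct_coeff are the helpers from the code below
dct_conv_weights = tf.constant(gen_filters(8, 8, dct_coeff), dtype=tf.float32)

# TF tensors are immutable, so torch's in-place unsqueeze_(1) becomes tf.expand_dims
dct_conv_weights = tf.expand_dims(dct_conv_weights, axis=1)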
I am putting the full code here. This code works on color images, but my images are grayscale:
import torch
import torch.nn as nn
import torch.nn.functional as F

# gen_filters, dct_coeff, idct_coeff, get_jpeg_yuv_filter_mask, rgb2yuv and
# yuv2rgb are helper functions defined elsewhere in the project.

class JpegCompression(nn.Module):
    def __init__(self, device, yuv_keep_weights=(25, 9, 9)):
        super(JpegCompression, self).__init__()
        self.device = device

        # 8x8 DCT and inverse-DCT filter banks stored as convolution weights
        self.dct_conv_weights = torch.tensor(gen_filters(8, 8, dct_coeff), dtype=torch.float32).to(self.device)
        self.dct_conv_weights.unsqueeze_(1)
        self.idct_conv_weights = torch.tensor(gen_filters(8, 8, idct_coeff), dtype=torch.float32).to(self.device)
        self.idct_conv_weights.unsqueeze_(1)

        self.yuv_keep_weights = yuv_keep_weights
        self.keep_coeff_masks = []
        self.jpeg_mask = None
        # create a new large mask which we can use by slicing for images which are smaller
        self.create_mask((1000, 1000))
    def create_mask(self, requested_shape):
        if self.jpeg_mask is None or requested_shape > self.jpeg_mask.shape[1:]:
            self.jpeg_mask = torch.empty((3,) + requested_shape, device=self.device)
            for channel, weights_to_keep in enumerate(self.yuv_keep_weights):
                mask = torch.from_numpy(get_jpeg_yuv_filter_mask(requested_shape, 8, weights_to_keep))
                self.jpeg_mask[channel] = mask

    def get_mask(self, image_shape):
        if self.jpeg_mask.shape < image_shape:
            self.create_mask(image_shape)
        # return the correct slice of it
        return self.jpeg_mask[:, :image_shape[1], :image_shape[2]].clone()
    def apply_conv(self, image, filter_type: str):
        if filter_type == 'dct':
            filters = self.dct_conv_weights
        elif filter_type == 'idct':
            filters = self.idct_conv_weights
        else:
            raise ValueError('Unknown filter_type value.')

        image_conv_channels = []
        for channel in range(image.shape[1]):
            image_yuv_ch = image[:, channel, :, :].unsqueeze_(1)
            # convolve with stride 8 so each output position holds the 64
            # coefficients of one 8x8 block
            image_conv = F.conv2d(image_yuv_ch, filters, stride=8)
            image_conv = image_conv.permute(0, 2, 3, 1)
            image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1], image_conv.shape[2], 8, 8)
            image_conv = image_conv.permute(0, 1, 3, 2, 4)
            # reassemble the 8x8 blocks back into a full-resolution channel
            image_conv = image_conv.contiguous().view(image_conv.shape[0],
                                                      image_conv.shape[1] * image_conv.shape[2],
                                                      image_conv.shape[3] * image_conv.shape[4])
            image_conv.unsqueeze_(1)
            image_conv_channels.append(image_conv)

        image_conv_stacked = torch.cat(image_conv_channels, dim=1)
        return image_conv_stacked
    def forward(self, noised_and_cover):
        noised_image = noised_and_cover[0]

        # pad the image so that we can do dct on 8x8 blocks
        pad_height = (8 - noised_image.shape[2] % 8) % 8
        pad_width = (8 - noised_image.shape[3] % 8) % 8
        noised_image = nn.ZeroPad2d((0, pad_width, 0, pad_height))(noised_image)

        # convert to yuv
        image_yuv = torch.empty_like(noised_image)
        rgb2yuv(noised_image, image_yuv)

        assert image_yuv.shape[2] % 8 == 0
        assert image_yuv.shape[3] % 8 == 0

        # apply dct
        image_dct = self.apply_conv(image_yuv, 'dct')
        # get the jpeg-compression mask
        mask = self.get_mask(image_dct.shape[1:])
        # multiply the dct-ed image with the mask
        image_dct_mask = torch.mul(image_dct, mask)

        # apply inverse dct (idct)
        image_idct = self.apply_conv(image_dct_mask, 'idct')

        # transform from yuv to rgb
        image_ret_padded = torch.empty_like(image_dct)
        yuv2rgb(image_idct, image_ret_padded)

        # un-pad
        noised_and_cover[0] = image_ret_padded[:, :, :image_ret_padded.shape[2] - pad_height,
                                               :image_ret_padded.shape[3] - pad_width].clone()
        return noised_and_cover
The code above is written in PyTorch, and I need to write it in Keras. Could you please help me? Thank you.
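For reference, here is a rough sketch of how the module skeleton might map to Keras; it is not a complete translation, the JpegCompressionKeras name is invented for illustration, and the gen_filters/dct_coeff helpers are the ones from the code above. The main differences: nn.Module becomes tf.keras.layers.Layer, forward becomes call, and tf.nn.conv2d expects NHWC inputs with (kH, kW, in, out) filters, so torch-style (out, in, kH, kW) weights need a transpose:

import tensorflow as tf

class JpegCompressionKeras(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # torch weights: (64, 8, 8) -> unsqueeze -> (64, 1, 8, 8) = (out, in, kH, kW)
        w = tf.constant(gen_filters(8, 8, dct_coeff), dtype=tf.float32)
        w = tf.expand_dims(w, axis=1)
        # tf.nn.conv2d expects filters shaped (kH, kW, in, out)
        self.dct_conv_weights = tf.transpose(w, perm=[2, 3, 1, 0])  # -> (8, 8, 1, 64)

    def call(self, image):
        # TF uses NHWC layout, so a grayscale batch is shaped (batch, H, W, 1)
        # instead of torch's (batch, 1, H, W)
        return tf.nn.conv2d(image, self.dct_conv_weights, strides=8, padding='VALID')

The remaining operations appear to translate mechanically (permute -> tf.transpose, view -> tf.reshape, torch.cat -> tf.concat, nn.ZeroPad2d -> tf.pad). Since the images are grayscale, the single channel can presumably be treated as Y directly, so the rgb2yuv/yuv2rgb conversion and the U/V masks could likely be dropped.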