pytorch中的unsqueez_和keras中的epxand_dim有什么区别,使用后输出的形状是什么?

时间:2019-04-29 20:18:46

标签: python keras pytorch

我是keras的初学者,我有一个pytorch代码,需要将其更改为keras,但我无法理解其中的一部分。特别是我在输出形状的尺寸上有问题。 image的形状是(:,3,32,32),image的第一维是批次的大小。现在,我的问题是:这条线是做什么的,输出形状是什么?

    image_yuv_ch = image[:, channel, :, :].unsqueeze_(1)

它在位置1添加了尺寸?输出形状是什么? 过滤器的大小为(64,8,8),然后我们得到filters.unsqueez_(1),这意味着filters的新形状是(64,1,8,8)吗? 这条线是做什么的? image_conv = F.conv2d(image_yuv_ch, filters, stride=8)与keras中的conv2d相同,输出张量的形状是什么?我还不明白该怎么办?我知道它试图以新形状显示张量,但是在下面的代码中,我无法理解每个unsqueez_permuteview之后的输出形状。您能告诉我每条线的输出形状是什么吗?预先谢谢你。

import torch.nn.functional as F
def apply_conv(self, image, filter_type: str):



        if filter_type == 'dct':
            filters = self.dct_conv_weights
        elif filter_type == 'idct':
            filters = self.idct_conv_weights
        else:
            raise('Unknown filter_type value.')

        image_conv_channels = []
        for channel in range(image.shape[1]):
            image_yuv_ch = image[:, channel, :, :].unsqueeze_(1)
            image_conv = F.conv2d(image_yuv_ch, filters, stride=8)
            image_conv = image_conv.permute(0, 2, 3, 1)
            image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1], image_conv.shape[2], 8, 8)
            image_conv = image_conv.permute(0, 1, 3, 2, 4)
            image_conv = image_conv.contiguous().view(image_conv.shape[0],
                                                  image_conv.shape[1]*image_conv.shape[2],
                                                  image_conv.shape[3]*image_conv.shape[4])

            image_conv.unsqueeze_(1)

            # image_conv = F.conv2d()
            image_conv_channels.append(image_conv)

        image_conv_stacked = torch.cat(image_conv_channels, dim=1)

        return image_conv_stacked

1 个答案:

答案 0 :(得分:1)

您似乎是Keras用户或Tensorflow用户,并试图学习Pytorch。 您应该转到Pytorch文档的website,以了解有关每个操作的更多信息。

  • unsqueeze用于将暗淡扩展张量1。 unsqueeze_()中的下划线表示这是in-place函数。
  • view()在喀拉拉语中可以理解为.reshape()
  • permute()用于切换张量的多个维度。例如:
x = torch.randn(1,2,3) # shape [1,2,3]
x = torch.permute(2,0,1) # shape [3,1,2]

为了知道每次操作后的张量形状,只需添加print(x.size())。例如:

image_conv = image_conv.permute(0, 2, 3, 1)
print(image_conv.size())

image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1], 
print(image_conv.size())

image_conv.shape[2], 8, 8)
print(image_conv.size())

image_conv = image_conv.permute(0, 1, 3, 2, 4)
print(image_conv.size())

Pytorch和Tensorflow(Keras的后端)之间的最大区别是Pytorch将生成动态图,而不是像Tensorflow那样生成静态图。您定义模型的方式在Pytorch中无法正常工作,因为conv的权重将不会保存在model.parameters()中,而在反向传播期间无法进行优化。

还有一条评论,请检查此link,以了解如何使用Pytorch定义适当的模型:

import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

评论代码:


import torch

x = torch.randn(8, 3, 32, 32)
print(x.shape)
torch.Size([8, 3, 32, 32])
channel = 1
y = x[:, channel, :, :]
print(y.shape)
torch.Size([8, 32, 32])

y = y.unsqueeze_(1)
print(y.shape)
torch.Size([8, 1, 32, 32])

希望这对您有所帮助,并祝您学习愉快!