如何增加火炬张量的通道数?

时间:2020-07-31 23:03:51

标签: python machine-learning pytorch tensor

我有一个[N, 2, H, W]格式的pytorch张量,其中2是通道数。但是,我正在使用的模型(预训练的resnet18)要求我具有尺寸[N, 3, H, W]。如何将频道数量从2增加到3?

1 个答案:

答案 0 :(得分:0)

将2通道图像作为灰度保存到磁盘,然后执行以下操作:

import torch
from PIL import Image
from torchvision.models import resnet18

from torchvision import transforms

transform = transforms.Compose([            
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)])

img= Image.open('pic.jpg').convert('RGB')
tensor= transform(img)
tensor= torch.unsqueeze(tensor, 0).float().cuda()

resnet_18_model= resnet18(pretrained= True).cuda() # resnet18()
resnet_18_model.eval()
output= resnet_18_model(tensor)

output= torch.argmax(output)
print('Class Number: ', output.item())