使用PyTorch加载FITS图像

时间:2018-05-08 10:20:08

标签: python pytorch astropy fits

我正在尝试使用PyTorch创建CNN,但我的图片需要从FITS格式导入,而不是传统的.png或.jpeg等。

有没有办法使用torch.utils.data.DataLoader轻松完成此操作,或者在源代码中是否有一个位置可以放入一个子句来加载时处理FITS文件?

我查看了文档,我发现最相关的东西是ToPILImage变换器,它将张量或ndarray转换为PIL图像。

目前我正在使用图片加载程序,如下所示:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose(
                   [transforms.Resize((32,32)),
                    transforms.ToTensor(),
                    ])

trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

Astropy:http://www.astropy.org/

Pytorch:https://pytorch.org/

torch.utils:https://pytorch.org/docs/master/data.html

更新:也许使用torchvision.datasets.DatasetFolder而不是DataLoader,在我自己的FITS处理程序中插入会起作用吗?

尝试使用此类时,我收到以下错误:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

此时,torchvision是否真正支持DatasetFolder?

3 个答案:

答案 0 :(得分:2)

通过阅读文档和代码的某些组合,我不认为你一定想要使用ImageFolder,因为它对FITS一无所知。

相反,您应该尝试使用更通用的DataSetFolder类(实际上是ImageFolder的父类)。你可以传递一个它应该处理的扩展列表(即['.fits']和一个" loader"函数,它接受一个FITS文件,似乎应该返回PIL.Image

您甚至可以按照ImageFolder的示例创建自己的子类。 E.g。

class FitsFolder(DatasetFolder):

    EXTENSIONS = ['.fits']

    def __init__(self, root, transform=None, target_transform=None,
                 loader=None):
        if loader is None:
            loader = self.__fits_loader

        super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                         transform=transform,
                                         target_transform=target_transform)

    @staticmethod
    def __fits_loader(filename):
        data = fits.getdata(filename)
        return Image.fromarray(data)

__fits_loader的具体细节可能取决于您的FITS文件的详细信息。这个基本示例只使用高级fits.getdata()函数,该函数返回FITS文件中的第一个图像数组(某些FITS文件可能包含许多具有许多图像的扩展,或者具有表等)。那部分将取决于你。

答案 1 :(得分:0)

您可以使用以下方法将FITS图像导出为pyplot.imsave()支持的任何格式:

mLeScanCallback.onLeScan()

答案 2 :(得分:0)

几周前,我遇到了与@ user8188120相同的问题。从文件夹结构中读取标签时,使用@Iguananaut的答案非常有用。如果有人偶然发现此问题并需要从csv文件中读取内容,那么这也可能会起作用:

labels = []
transform = transforms.Compose([
    # here go your transforms
    ])


class MyFitsDataset(data.Dataset):
    def __init__(self, csv_path):
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # the rest contain the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1:])  # for multi-label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])  # for single-label
        labels.append(self.label_arr)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        single_image_name = self.image_arr[index]

        data = pyfits.open(single_image_name, axes=2)
        data = data[0].data.astype('float32')
        data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)

        img = transform(data)

        # Get label(class) of the image based on the pandas column
        single_image_label = self.label_arr[index]

        return (img, single_image_label)

    def __len__(self):
        return self.data_len

这也避免了使用DatasetFolder类,该类在最新版本的PyTorch中仍然不可用。我希望这对某人有帮助。