我正在尝试使用PyTorch创建CNN,但我的图片需要从FITS格式导入,而不是传统的.png或.jpeg等。
有没有办法使用torch.utils.data.DataLoader轻松完成此操作,或者在源代码中是否有一个位置可以放入一个子句来加载时处理FITS文件?
我查看了文档,我发现最相关的东西是ToPILImage变换器,它将张量或ndarray转换为PIL图像。
目前我正在使用图片加载程序,如下所示:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision
batch_size = 4
transform = transforms.Compose(
[transforms.Resize((32,32)),
transforms.ToTensor(),
])
trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)
Astropy:http://www.astropy.org/
Pytorch:https://pytorch.org/
torch.utils:https://pytorch.org/docs/master/data.html
更新:也许使用torchvision.datasets.DatasetFolder而不是DataLoader,在我自己的FITS处理程序中插入会起作用吗?
尝试使用此类时,我收到以下错误:
AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'
此时,torchvision是否真正支持DatasetFolder?
答案 0 :(得分:2)
通过阅读文档和代码的某些组合,我不认为你一定想要使用ImageFolder
,因为它对FITS一无所知。
相反,您应该尝试使用更通用的DataSetFolder
类(实际上是ImageFolder
的父类)。你可以传递一个它应该处理的扩展列表(即['.fits']
和一个" loader"函数,它接受一个FITS文件,似乎应该返回PIL.Image
。
您甚至可以按照ImageFolder
的示例创建自己的子类。 E.g。
class FitsFolder(DatasetFolder):
EXTENSIONS = ['.fits']
def __init__(self, root, transform=None, target_transform=None,
loader=None):
if loader is None:
loader = self.__fits_loader
super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
transform=transform,
target_transform=target_transform)
@staticmethod
def __fits_loader(filename):
data = fits.getdata(filename)
return Image.fromarray(data)
__fits_loader
的具体细节可能取决于您的FITS文件的详细信息。这个基本示例只使用高级fits.getdata()
函数,该函数返回FITS文件中的第一个图像数组(某些FITS文件可能包含许多具有许多图像的扩展,或者具有表等)。那部分将取决于你。
答案 1 :(得分:0)
您可以使用以下方法将FITS图像导出为pyplot.imsave()支持的任何格式:
mLeScanCallback.onLeScan()
答案 2 :(得分:0)
几周前,我遇到了与@ user8188120相同的问题。从文件夹结构中读取标签时,使用@Iguananaut的答案非常有用。如果有人偶然发现此问题并需要从csv文件中读取内容,那么这也可能会起作用:
labels = []
transform = transforms.Compose([
# here go your transforms
])
class MyFitsDataset(data.Dataset):
def __init__(self, csv_path):
# Read the csv file
self.data_info = pd.read_csv(csv_path, header=None)
# First column contains the image paths
self.image_arr = np.asarray(self.data_info.iloc[:, 0])
# the rest contain the labels
self.label_arr = np.asarray(self.data_info.iloc[:, 1:]) # for multi-label
self.label_arr = np.asarray(self.data_info.iloc[:, 1]) # for single-label
labels.append(self.label_arr)
self.data_len = len(self.data_info.index)
def __getitem__(self, index):
single_image_name = self.image_arr[index]
data = pyfits.open(single_image_name, axes=2)
data = data[0].data.astype('float32')
data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)
img = transform(data)
# Get label(class) of the image based on the pandas column
single_image_label = self.label_arr[index]
return (img, single_image_label)
def __len__(self):
return self.data_len
这也避免了使用DatasetFolder
类,该类在最新版本的PyTorch中仍然不可用。我希望这对某人有帮助。