Converting an mha 3D image into 2D images in Python (2015 BRATS challenge dataset)

Asked: 2017-11-08 06:44:45

Tags: python arrays tensorflow simpleitk

I want to convert a 3D image into 2D images using SimpleITK or medpy. In other words, I want to get a 3D matrix and then split that 3D matrix into a set of 2D matrices.

import SimpleITK as ITK
import numpy as np
#from medpy.io import load
url = r'G:\path\to\my.mha'
image = ITK.ReadImage(url)
image_array = ITK.GetArrayFromImage(image)  # convert the 3D image to a numpy array
frame_num, width, height = image_array.shape
print(frame_num, width, height)

Running this prints: 155 240 240

But what I want is something like [[1,5,2,3,1,...],[54,1,3,5,...],[5,8,9,6,...]], i.e. plain 2D matrices.
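
For reference, a minimal sketch of that result, assuming the same file as above: each index along the first axis of the numpy array returned by GetArrayFromImage is already a 240 x 240 2D matrix.

import SimpleITK as ITK

image = ITK.ReadImage(r'G:\path\to\my.mha')   # same file as above
image_array = ITK.GetArrayFromImage(image)    # shape (155, 240, 240)

# one 2D matrix per slice
slices_2d = [image_array[i] for i in range(image_array.shape[0])]
print(slices_2d[0].shape)                     # (240, 240)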

4 Answers:

Answer 0 (score: 1)

Just to add to Dave Chen's answer, since it isn't clear whether you want a set of 2D SimpleITK images or numpy arrays. The following code covers all three available options:

import SimpleITK as sitk
import numpy as np

url = "my_file.mha"

image = sitk.ReadImage(url)

max_index = image.GetDepth() # or image.GetWidth() or image.GetHeight() depending on the axis along which you want to extract

# As list of 2D SimpleITK images
list_of_2D_images = [image[:,:,i] for i in range(max_index)]

# As list of 2D numpy arrays which cannot be modified (no data copied) 
list_of_2D_images_np_view = [sitk.GetArrayViewFromImage(image[:,:,i]) for i in range(max_index)]

# As list of 2D numpy arrays (data copied to numpy array)
list_of_2D_images_np = [sitk.GetArrayFromImage(image[:,:,i]) for i in range(max_index)]

Additionally, if you really do want to work with URLs rather than local files, I suggest you look at the remote download approach used in the SimpleITK notebooks repository; the relevant file is downloaddata.py.
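
If all you need is to fetch a single remote file before reading it (not the full downloaddata.py machinery), a minimal sketch could look like this; the URL below is a placeholder, not a real download link:

import urllib.request
import SimpleITK as sitk

remote_url = "http://example.com/volume.mha"  # placeholder URL, replace with a real link
local_path = "volume.mha"

urllib.request.urlretrieve(remote_url, local_path)  # download to a local file
image = sitk.ReadImage(local_path)                  # then read it as usual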

Answer 1 (score: 1)

This is not a big deal. CT images are originally stored as int16 numbers, so you don't need to deal with floats. In this case we can simply go from int16 to uint16 by removing the negative values from the image (CT images have some negative pixel values). Note that we do need a uint16 or uint8 type so that OpenCV can handle it; since a CT image array holds a wide range of values, uint16 is the better choice so we don't lose too much precision. OK, now you just need to do the following:

import SimpleITK as sitk
import numpy as np
import cv2

mha   = sitk.ReadImage('/mha/directory')  #Importing mha file
array = sitk.GetArrayFromImage(mha)       #Converting to array int16 (default)

#Translating each slice to the positive side
for m in range(array.shape[0]):
    array[m] = array[m] + abs(np.min(array[m]))

array = np.around(array, decimals=0)      #round away any fractional values, if any (probably none)
array = np.asarray(array, dtype='uint16') #From int16 to uint16

After these steps, the array is ready to be saved as PNG images with cv2.imwrite:

for i, image in enumerate(array):
    cv2.imwrite('/dir/to/save/' + 'name_image_{}.png'.format(i), image)

Note that by default SimpleITK handles the .mha file in the axial view. I honestly don't know how to change that, since I've never needed to, but with a little searching you should be able to find something for that case.
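
One workaround, as a minimal sketch: since GetArrayFromImage returns the volume as (slices, rows, cols), you can reorder the axes with np.transpose and iterate along the new first axis to get the other two orthogonal views (whether they correspond to coronal or sagittal depends on the image orientation). The path is the same placeholder used above; the variable names are mine.

import SimpleITK as sitk
import numpy as np

array = sitk.GetArrayFromImage(sitk.ReadImage('/mha/directory'))  # shape (slices, rows, cols)

# array[i] walks through axial slices; the transposed views slice the other axes
view_b = np.transpose(array, (1, 0, 2))  # view_b[i] == array[:, i, :]
view_c = np.transpose(array, (2, 0, 1))  # view_c[i] == array[:, :, i]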

Answer 2 (score: 0)

I'm not sure exactly what you want to get, but extracting a 2D slice from a 3D image is easy in SimpleITK.

To get the Z slice at Z = 100, you can do: zslice = image[:,:,100]

To get the Y slice at Y = 100: yslice = image[:,100,:]

And the X slice at X = 100: xslice = image[100,:,:]
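
To write one of those extracted slices straight out of SimpleITK (a minimal sketch, assuming the file path from the question; the output filename is arbitrary), rescale and cast to an 8-bit pixel type first so it can be saved as PNG:

import SimpleITK as sitk

image = sitk.ReadImage(r'G:\path\to\my.mha')
zslice = image[:, :, 100]                    # a 2D SimpleITK image

# PNG needs an integer pixel type; rescale to 0-255 and cast to uint8
zslice_uint8 = sitk.Cast(sitk.RescaleIntensity(zslice, 0, 255), sitk.sitkUInt8)
sitk.WriteImage(zslice_uint8, 'zslice_100.png')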

Answer 3 (score: 0)

@zivy @Dave Chen I have solved my problem. In fact, running this code gives you 155 PNG images of 240 * 240, which is what I wanted.

# -*- coding:utf-8 -*-
import numpy as np
import subprocess
import random
import progressbar
from glob import glob
from skimage import io

np.random.seed(5) # for reproducibility
progress = progressbar.ProgressBar(widgets=[progressbar.Bar('*', '[', ']'), progressbar.Percentage(), ' '])

class BrainPipeline(object):
    '''
    A class for processing brain scans for one patient
    INPUT:  (1) filepath 'path': path to directory of one patient. Contains following mha files:
            flair, t1, t1c, t2, ground truth (gt)
            (2) bool 'n4itk': True to use n4itk normed t1 scans (defaults to True)
            (3) bool 'n4itk_apply': True to apply and save n4itk filter to t1 and t1c scans for given patient. This will only work if the
    '''
    def __init__(self, path, n4itk = True, n4itk_apply = False):
        self.path = path
        self.n4itk = n4itk
        self.n4itk_apply = n4itk_apply
        self.modes = ['flair', 't1', 't1c', 't2', 'gt']
        # slices=[[flair x 155], [t1], [t1c], [t2], [gt]], 155 per modality
        self.slices_by_mode, n = self.read_scans()
        # [ [slice1 x 5], [slice2 x 5], ..., [slice155 x 5]]
        self.slices_by_slice = n
        self.normed_slices = self.norm_slices()

    def read_scans(self):
        '''
        goes into each modality in patient directory and loads individual scans.
        transforms scans of same slice into strip of 5 images
        '''
        print('Loading scans...')
        slices_by_mode = np.zeros((5, 155, 240, 240))
        slices_by_slice = np.zeros((155, 5, 240, 240))
        flair = glob(self.path + '/*Flair*/*.mha')
        t2 = glob(self.path + '/*_T2*/*.mha')
        gt = glob(self.path + '/*more*/*.mha')
        t1s = glob(self.path + '/**/*T1*.mha')
        t1_n4 = glob(self.path + '/*T1*/*_n.mha')
        t1 = [scan for scan in t1s if scan not in t1_n4]
        scans = [flair[0], t1[0], t1[1], t2[0], gt[0]] # directories to each image (5 total)
        if self.n4itk_apply:
            print('-> Applying bias correction...')
            for t1_path in t1:
                self.n4itk_norm(t1_path) # normalize files
            scans = [flair[0], t1_n4[0], t1_n4[1], t2[0], gt[0]]
        elif self.n4itk:
            scans = [flair[0], t1_n4[0], t1_n4[1], t2[0], gt[0]]
        for scan_idx in xrange(5):
            # read each image directory, save to self.slices
            slices_by_mode[scan_idx] = io.imread(scans[scan_idx], plugin='simpleitk').astype(float)
        for mode_ix in xrange(slices_by_mode.shape[0]): # modes 1 thru 5
            for slice_ix in xrange(slices_by_mode.shape[1]): # slices 1 thru 155
                slices_by_slice[slice_ix][mode_ix] = slices_by_mode[mode_ix][slice_ix] # reshape by slice
        return slices_by_mode, slices_by_slice

    def norm_slices(self):
        '''
        normalizes each slice in self.slices_by_slice, excluding gt
        subtracts mean and div by std dev for each slice
        clips top and bottom one percent of pixel intensities
        if n4itk == True, will apply n4itk bias correction to T1 and T1c images
        '''
        print('Normalizing slices...')
        normed_slices = np.zeros((155, 5, 240, 240))
        for slice_ix in xrange(155):
            normed_slices[slice_ix][-1] = self.slices_by_slice[slice_ix][-1]
            for mode_ix in xrange(4):
                normed_slices[slice_ix][mode_ix] =  self._normalize(self.slices_by_slice[slice_ix][mode_ix])
        print('Done.')
        return normed_slices

    def _normalize(self, slice):
        '''
        INPUT:  (1) a single slice of any given modality (excluding gt)
                (2) index of modality assoc with slice (0=flair, 1=t1, 2=t1c, 3=t2)
        OUTPUT: normalized slice
        '''
        b, t = np.percentile(slice, (0.5,99.5))
        slice = np.clip(slice, b, t)
        if np.std(slice) == 0:
            return slice
        else:
            return (slice - np.mean(slice)) / np.std(slice)

    def save_patient(self, reg_norm_n4, patient_num):
        '''
        INPUT:  (1) int 'patient_num': unique identifier for each patient
                (2) string 'reg_norm_n4': 'reg' for original images, 'norm' normalized images, 'n4' for n4 normalized images
        OUTPUT: saves png in Norm_PNG directory for normed, Training_PNG for reg
        '''
        print('Saving scans for patient {}...'.format(patient_num))
        progress.currval = 0
        if reg_norm_n4 == 'norm':  # save normed slices
            for slice_ix in progress(xrange(155)): # reshape to strip
                strip = self.normed_slices[slice_ix].reshape(1200, 240)
                if np.max(strip) != 0: # set values < 1
                    strip /= np.max(strip)
                if np.min(strip) <= -1: # set values > -1
                    strip /= abs(np.min(strip))
                # save as patient_slice.png
                io.imsave('Norm_PNG/{}_{}.png'.format(patient_num, slice_ix), strip)
        elif reg_norm_n4 == 'reg':
            for slice_ix in progress(xrange(155)):
                strip = self.slices_by_slice[slice_ix].reshape(1200, 240)
                if np.max(strip) != 0:
                    strip /= np.max(strip)
                io.imsave('Training_PNG/{}_{}.png'.format(patient_num, slice_ix), strip)
        else:
            for slice_ix in progress(xrange(155)): # reshape to strip
                strip = self.normed_slices[slice_ix].reshape(1200, 240)
                if np.max(strip) != 0: # set values < 1
                    strip /= np.max(strip)
                if np.min(strip) <= -1: # set values > -1
                    strip /= abs(np.min(strip))
                # save as patient_slice.png
                io.imsave('n4_PNG/{}_{}.png'.format(patient_num, slice_ix), strip)

    def n4itk_norm(self, path, n_dims=3, n_iters='[20,20,10,5]'):
        '''
        INPUT:  (1) filepath 'path': path to mha T1 or T1c file
                (2) directory 'parent_dir': parent directory to mha file
        OUTPUT: writes n4itk normalized image to parent_dir under orig_filename_n.mha
        '''
        output_fn = path[:-4] + '_n.mha'
        # run n4_bias_correction.py path n_dim n_iters output_fn
        subprocess.call('python n4_bias_correction.py ' + path + ' ' + str(n_dims) + ' ' + n_iters + ' ' + output_fn, shell = True)


def save_patient_slices(patients, type):
    '''
    INPUT   (1) list 'patients': paths to any directories of patients to save. for example- glob("Training/HGG/**")
            (2) string 'type': options = reg (non-normalized), norm (normalized, but no bias correction), n4 (bias corrected and normalized)
    saves strips of patient slices to appropriate directory (Training_PNG/, Norm_PNG/ or n4_PNG/) as patient-num_slice-num
    '''
    for patient_num, path in enumerate(patients):
        a = BrainPipeline(path)
        a.save_patient(type, patient_num)

def s3_dump(directory, bucket):
    '''
    dump files from a given directory to an s3 bucket
    INPUT   (1) string 'directory': directory containing files to save
            (2) string 'bucket': name of s3 bucket to dump files
    '''
    subprocess.call('aws s3 cp' + ' ' + directory + ' ' + 's3://' + bucket + ' ' + '--recursive', shell=True)

def save_labels(fns):
    '''
    INPUT list 'fns': filepaths to all labels
    '''
    progress.currval = 0
    for label_idx in progress(range(len(fns))):
        slices = io.imread(fns[label_idx], plugin='simpleitk')
        for slice_idx in range(len(slices)):
            io.imsave(r'{}_{}L.png'.format(label_idx, slice_idx), slices[slice_idx])


if __name__ == '__main__':
    url = r'G:\work\deeplearning\BRATS2015_Training\HGG\brats_2013_pat0005_1\VSD.Brain.XX.O.MR_T1.54537\VSD.Brain.XX.O.MR_T1.54537.mha'
    labels = glob(url)
    save_labels(labels)
    # patients = glob('Training/HGG/**')
    # save_patient_slices(patients, 'reg')
    # save_patient_slices(patients, 'norm')
    # save_patient_slices(patients, 'n4')
    # s3_dump('Graveyard/Training_PNG/', 'orig-training-png')