I am trying to do some SVM classification on DICOM images of size (91, 109, 91). There are 400 images in total, and each subject is either healthy (0) or has a positive scan (1).
I wrote some simple code that loops over all the DICOM files in a directory, reads the pixel values into a numpy array and flattens it. For each DICOM file I also look up the subject's status (0 or 1) in a CSV file and prepend it to the array. At the end of the loop I have a 2D numpy array with one row per patient, holding that patient's pixel values and status.
import os
import numpy as np
import dicom
from dicom.errors import InvalidDicomError
import sklearn
import matplotlib.pyplot as plt
from sklearn import decomposition
from sklearn import cross_validation
from sklearn import svm
import csv
import re

dirName = '/home/nm/MachineLearning/DaTSCAN/PPMI/'
# Name of csv file
results_dat = 'PPMIdatabase.csv'
results_path = os.path.join(dirName, results_dat)

# make an empty list that we will populate with the flattened dicom pixel values
data = []
for filename in os.listdir(dirName):
    dicom_file = os.path.join(dirName, filename)
    if os.path.isfile(dicom_file) and filename.endswith(".dcm"):
        try:
            # look the study up in the csv file to get the diagnosis
            # Get study number from the dicom filename
            study_id = int(re.search(r'\d+', filename).group())
            with open(results_path, 'r') as f:
                reader = csv.reader(f)
                search_group = [line[1] for line in reader if line[0] == str(study_id)]
            group = str(search_group[0])
            # HC 0, PD 1
            if group == 'HC':
                group_id = 0
            else:
                group_id = 1
            ds = dicom.read_file(dicom_file)
            img = ds.pixel_array
            a = np.reshape(img, [img.size, 1], 'C')
            # Prepend group_id to the flattened pixel values
            a = np.insert(a, 0, group_id)
            data.append(a)
        except InvalidDicomError:
            print("File %s cannot be opened by dicom.read_file" % (filename))

# turn the python list into a 2d numpy array
full_data = np.array(data)
# Want to predict Y
Y = full_data[:, 0]  # first column of the array is the classification status 0 or 1
# Image data
X = full_data[:, 1:]
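To give a sense of scale, here is my own rough estimate of how big full_data gets, based only on the image dimensions above (I am assuming pixel_array comes back as 16-bit integers, and that the values get cast to float64 somewhere down the line):

import numpy as np

voxels_per_scan = 91 * 109 * 91      # 902,629 voxels per image
n_subjects = 400
n_features = voxels_per_scan + 1     # +1 for the group label column

# rough footprint of full_data if the pixel data stays 16-bit
print("%.2f GB as int16" % (n_subjects * n_features * 2 / 1024.0 ** 3))
# rough footprint once the data is cast to float64
print("%.2f GB as float64" % (n_subjects * n_features * 8 / 1024.0 ** 3))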
I then want to run cross-validation to evaluate how well the estimator performs (using scikit-learn):
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X,Y, test_size=0.4,random_state=0)
clf = svm.SVC(kernel='linear', C=1).fit(X_train, y_train)
clf.score(X_test, y_test)
However, I get the following error, which indicates a memory problem:
Traceback (most recent call last):
File "cross_validation.py", line 56, in <module>
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X,Y, test_size=0.4,random_state=0)
File "/home/nm/.local/lib/python2.7/site-packages/sklearn/cross_validation.py", line 1919, in train_test_split
safe_indexing(a, test)) for a in arrays))
File "/home/nm/.local/lib/python2.7/site-packages/sklearn/cross_validation.py", line 1919, in <genexpr>
safe_indexing(a, test)) for a in arrays))
File "/home/nm/.local/lib/python2.7/site-packages/sklearn/utils/__init__.py", line 163, in safe_indexing
return X.take(indices, axis=0)
MemoryError
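From the traceback, safe_indexing ends up calling X.take(indices, axis=0), and as far as I understand take (like any fancy indexing) returns a copy rather than a view, so train_test_split has to allocate the train and test subsets on top of the existing full_data array. A tiny check of that behaviour:

import numpy as np

X_small = np.zeros((4, 3), dtype=np.int16)
subset = X_small.take([0, 2], axis=0)   # the same call the traceback shows in safe_indexing
subset[0, 0] = 99
print(X_small[0, 0])                    # still 0, so subset is an independent copy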
My numpy array is simply too large. How can I work around this?