这部分课程我不明白这段代码的作用:
for file in os.listdir(path):
if(os.path.isfile(os.path.join(path,file)) and select in file):
temp = scipy.io.loadmat(os.path.join(path,file))
temp = {k:v for k, v in temp.items() if k[0] != '_'}
for i in range(len(temp[patch_type+"_patches"])):
self.tensors.append(temp[patch_type+"_patches"][i])
self.labels.append(temp[patch_type+"_labels"][0][i])
self.tensors = np.array(self.tensors)
self.labels = np.array(self.labels)
尤其是这一行:
temp = {k:v for k, v in temp.items() if k[0] != '_'}
整个课程如下:
class Datasets(Dataset):
def __init__(self,path,train,transform=None):
if(train):
select ="Training"
patch_type = "train"
else:
select = "Testing"
patch_type = "testing"
self.tensors = []
self.labels = []
self.transform = transform
for file in os.listdir(path):
if(os.path.isfile(os.path.join(path,file)) and select in file):
temp = scipy.io.loadmat(os.path.join(path,file))
temp = {k:v for k, v in temp.items() if k[0] != '_'}
for i in range(len(temp[patch_type+"_patches"])):
self.tensors.append(temp[patch_type+"_patches"][i])
self.labels.append(temp[patch_type+"_labels"][0][i])
self.tensors = np.array(self.tensors)
self.labels = np.array(self.labels)
def __len__(self):
try:
if len(self.tensors) != len(self.labels):
raise Exception("Lengths of the tensor and labels list are not the same")
except Exception as e:
print(e.args[0])
return len(self.tensors)
def __getitem__(self,idx):
sample = (self.tensors[idx],self.labels[idx])
# print(self.labels)
sample = (torch.from_numpy(self.tensors[idx]),torch.from_numpy(np.array(self.labels[idx])).long())
return sample
#tuple containing the image patch and its corresponding label
答案 0 :(得分:2)
这是dict comprehension;在这种特殊情况下,它将根据现有字典dict
创建一个新的temp
,但仅针对键k
并非以下划线开头的项目。该检查由if ...
部分执行。
等效于
new = {}
for k, v in temp.items():
if key[0] != '_':
new[k] = value
temp = new
或者,稍有不同:
new = {}
for key, value in temp.items():
if not key.startswith('_'):
new[key] = value
temp = new
您会发现它看起来像是一行更好,因为它避免了一个临时字典(new
;在幕后,它仍然会创建一个无名的临时字典)。
答案 1 :(得分:0)
它正在从加载的MATLAB文件中滤除带下划线前缀的变量。函数scipy.io.loadmat
从scipy documentation返回一个字典,该字典包含已加载文件中的变量名作为键,矩阵作为值。您引用的代码行是dictionary comprehension,它会克隆字典减去未通过条件检查的变量。
这里发生的大致是这样:
file
),其中键是文件中的变量名,值是矩阵,并分配给temp
。temp
。