我无法理解这行代码的作用?

时间:2018-11-21 12:48:46

标签: python scipy

这部分课程我不明白这段代码的作用:

for file in os.listdir(path):
    if(os.path.isfile(os.path.join(path,file)) and select in file): 
        temp = scipy.io.loadmat(os.path.join(path,file))
        temp = {k:v for k, v in temp.items() if k[0] != '_'}
        for i  in range(len(temp[patch_type+"_patches"])):
            self.tensors.append(temp[patch_type+"_patches"][i])
            self.labels.append(temp[patch_type+"_labels"][0][i])

self.tensors = np.array(self.tensors)
self.labels = np.array(self.labels)

尤其是这一行:

temp = {k:v for k, v in temp.items() if k[0] != '_'}

整个课程如下:

class Datasets(Dataset):
    def __init__(self,path,train,transform=None):
        if(train):
            select ="Training"
            patch_type = "train"
        else:
            select = "Testing"
            patch_type = "testing"

        self.tensors = []
        self.labels = []
        self.transform = transform


        for file in os.listdir(path):
            if(os.path.isfile(os.path.join(path,file)) and select in file): 

                temp = scipy.io.loadmat(os.path.join(path,file))
                temp = {k:v for k, v in temp.items() if k[0] != '_'}
                for i  in range(len(temp[patch_type+"_patches"])):
                    self.tensors.append(temp[patch_type+"_patches"][i])
                    self.labels.append(temp[patch_type+"_labels"][0][i])

        self.tensors = np.array(self.tensors)
        self.labels = np.array(self.labels)

    def __len__(self):
        try:
            if len(self.tensors) != len(self.labels):
                raise Exception("Lengths of the tensor and labels list are not the same")
        except Exception as e:
            print(e.args[0])
        return len(self.tensors)

    def __getitem__(self,idx):
        sample = (self.tensors[idx],self.labels[idx])
       # print(self.labels)
        sample = (torch.from_numpy(self.tensors[idx]),torch.from_numpy(np.array(self.labels[idx])).long())
        return sample
    #tuple containing the image patch and its corresponding label

2 个答案:

答案 0 :(得分:2)

这是dict comprehension;在这种特殊情况下,它将根据现有字典dict创建一个新的temp,但仅针对键k并非以下划线开头的项目。该检查由if ...部分执行。

等效于

new = {}
for k, v in temp.items():
    if key[0] != '_':
        new[k] = value
temp = new

或者,稍有不同:

new = {}
for key, value in temp.items():
    if not key.startswith('_'):
        new[key] = value
temp = new

您会发现它看起来像是一行更好,因为它避免了一个临时字典(new;在幕后,它仍然会创建一个无名的临时字典)。

答案 1 :(得分:0)

它正在从加载的MATLAB文件中滤除带下划线前缀的变量。函数scipy.io.loadmatscipy documentation返回一个字典,该字典包含已加载文件中的变量名作为键,矩阵作为值。您引用的代码行是dictionary comprehension,它会克隆字典减去未通过条件检查的变量。

更新

这里发生的大致是这样:

  1. 以哈希图(字典)的形式加载MATLAB文件(在代码中为file),其中键是文件中的变量名,值是矩阵,并分配给temp
  2. li>
  3. 遍历那些键/值对,并删除下划线前缀的对,然后将 迭代的结果重新分配给temp
  4. 利润