如何在 TensorFlow 中批量加载图像？

时间:2019-06-10 07:52:29

标签: python image tensorflow

我一直在使用 PyTorch，刚开始接触 TensorFlow。在 PyTorch 中，我们会编写一个数据集类，它每次返回数据集中的单个样本元组，然后把该类的实例传给 torch.utils.data.DataLoader，并指定 batch_size 和 num_workers。在 TensorFlow 中如何实现同样的流程？

class Dataset_Tuple_Loader(data.Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a directory.

    Image paths are collected from ``img_root`` (or from the location returned
    by ``DataDownload`` when ``img_source == 'network'``). Labels are read from
    a CSV that has an ``image_name`` column. Each image is matched to its row
    by file name without directory or extension.
    """

    def __init__(self, img_root, csv_label_address, transforms, img_source='local'):
        """Index the image files and load the label table.

        Args:
            img_root: Directory containing the images. For
                ``img_source='network'`` it is passed to ``DataDownload``.
            csv_label_address: Path of the label CSV. It must contain an
                ``image_name`` column.
            transforms: Callable applied to each decoded image array.
            img_source: ``'local'`` (default) or ``'network'``.
        """
        add = img_root
        if img_source == 'network':
            # DataDownload's third return value is not used here.
            add, csv_label_address, _ = DataDownload(img_root, csv_label_address)
        # glob order is arbitrary, so sort to make the index -> file mapping
        # reproducible. Skip sub-directories, which Image.open cannot decode.
        self.imgs_address_list = sorted(
            path
            for path in glob.glob(os.path.join(add, '*'))
            if os.path.isfile(path)
        )
        labels_csv = pd.read_csv(csv_label_address)
        labels_csv.set_index('image_name', inplace=True)
        self.labels_csv = labels_csv
        self.transforms = transforms

    def __len__(self):
        """Return the number of image files found."""
        return len(self.imgs_address_list)

    def get_filename(self, filename):
        """Return the base name of *filename* without directory or extension."""
        base = os.path.basename(filename)
        return os.path.splitext(base)[0]

    def _label_for(self, img_add):
        """Return the integer label of the image stored at path *img_add*."""
        row = self.labels_csv.loc[self.get_filename(img_add)]
        # Take the first column positionally. A plain ``row[0]`` on a
        # string-indexed Series relies on pandas' deprecated
        # label -> position fallback.
        return int(row.iloc[0])

    def __getitem__(self, index):
        """Load, transform and label the image at position *index*.

        Returns:
            ``(img, label)``. ``img`` is ``self.transforms`` applied to the
            decoded image array, and ``label`` is an ``int``.
        """
        img_add = self.imgs_address_list[index]
        # Close the image explicitly: Pillow keeps multi-frame files open until
        # close(). np.array copies the pixels, so the array stays valid after
        # the file is closed, and unlike np.asarray it is writable.
        with Image.open(img_add) as img:
            arr = np.array(img)
        img = self.transforms(arr)
        return img, self._label_for(img_add)


def _identity_transform(arr):
    """Return the decoded image array unchanged.

    Placeholder for a real preprocessing pipeline (resize, normalize, ...).
    It is a module-level function rather than a lambda so it can be pickled
    and sent to DataLoader worker processes. DataLoader's default collate
    converts the NumPy arrays to tensors. Presumably all images must share one
    shape for batching to work; confirm this for your data.
    """
    return arr


params = {
    'batch_size': 2,
    'shuffle': True,
    'num_workers': 1,
}

csv = 'label_address'
image_root = 'img_add'


# With num_workers > 0, worker processes may re-import this module (the
# default "spawn" start method on Windows/macOS). The guard keeps them from
# re-running the loading loop.
if __name__ == '__main__':
    # Match Dataset_Tuple_Loader's signature:
    # (img_root, csv_label_address, transforms, img_source).
    just_set = Dataset_Tuple_Loader(
        image_root, csv, _identity_transform, img_source='local'
    )

    generator = data.DataLoader(just_set, **params)

    for img, lb in generator:
        print(img.shape, ' ', lb)

0 个答案:

没有答案