代码如下:
class Dataset(Dataset):
    """Map-style dataset of images indexed by a CSV file.

    Each CSV row describes one sample. Column 2 (0-based) holds the image
    file name relative to ``data_dir``, and column 3 holds the class label.

    NOTE(review): the base class is whatever ``Dataset`` was bound to before
    this definition (presumably ``torch.utils.data.Dataset``; the import is
    not visible in this snippet). Reusing the same name shadows that base for
    the rest of the module. The name is kept so that existing calls such as
    ``Dataset(csv_file=..., data_dir=...)`` keep working.
    """

    def __init__(self, csv_file, data_dir, transform=None):
        """Read the CSV index and remember where the images live.

        Args:
            csv_file: Anything ``pd.read_csv`` accepts (path or file-like
                object) describing the samples.
            data_dir: Directory that contains the image files named in
                column 2 of the CSV.
            transform: Optional callable applied to each loaded image.
        """
        # Directory containing the image files.
        self.data_dir = data_dir
        # Applied to each image in __getitem__ when truthy.
        self.transform = transform
        # Load the CSV file that contains the per-image info.
        self.data_name = pd.read_csv(csv_file)
        # Number of samples == number of CSV rows.
        self.len = self.data_name.shape[0]

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(image, label)`` for CSV row ``idx``.

        The image file is ``data_dir`` joined with column 2 of the row.
        ``transform`` is applied to it when one was given. The label is the
        raw value stored in column 3.
        """
        import os  # local import: the file's top-level imports are not visible here

        # os.path.join works whether or not data_dir ends with a separator;
        # plain string concatenation builds a wrong path when it does not.
        img_name = os.path.join(self.data_dir, self.data_name.iloc[idx, 2])
        # Image.open is lazy and leaves the file handle open. Reading the
        # pixels inside ``with`` and then leaving the block closes the file,
        # so iterating the dataset does not accumulate open file descriptors.
        # The loaded image stays usable after the block exits.
        with Image.open(img_name) as image:
            image.load()
        # The class label for the image.
        y = self.data_name.iloc[idx, 3]
        # If there is any transform method, apply it to the image.
        if self.transform:
            image = self.transform(image)
        return image, y
# ----------------------------------------------------------------------
# Instantiate the training dataset. `train_csv_file` is not defined in
# this snippet; it must be bound earlier in the file.
train_dataset = Dataset(
    csv_file=train_csv_file,
    data_dir='/resources/data/training_data_pytorch/',
)
我想回答以下问题:
问题:打印出训练数据中的三个样本图像及其类别。
样本索引 = [53, 23, 10]
谢谢!