我想用训练数据和验证数据训练我的数据集。 总数据为 2048,训练数据为 1638,验证数据为 410(占总数的 20%)。
这是我的代码
加载数据(组织:总训练数据)
org_x = train_csv.drop(['id', 'digit', 'letter'], axis=1).values
org_x = org_x.reshape(-1, 28, 28, 1)
org_x = org_x/255
org_x = np.array(org_x)
org_x = org_x.reshape(-1, 1, 28, 28)
org_x = torch.Tensor(org_x)
x_test = test_csv.drop(['id','letter'], axis=1).values
x_test = x_test.reshape(-1, 28, 28, 1)
x_test = x_test/255
x_test = np.array(x_test)
x_test = x_test.reshape(-1, 1, 28, 28)
x_test = torch.Tensor(x_test)
y = train_csv['digit']
y = list(y)
print(len(y))
org_y = np.zeros([len(y), 1])
for i in range(len(y)):
org_y[i] = y[i]
org1 = np.array(org_y, dtype=object)
分割数据(组织:总训练数据)
from sklearn.model_selection import train_test_split
x_train, x_valid, y_train, y_valid = train_test_split(
org, org1, test_size=0.2, random_state=42)
变换
transform = transforms.Compose([transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, )) ])
数据集
class kmnistDataset(data.Dataset):
def __init__(self, images, labels=None, transforms=None):
self.x = images
self.y = labels
self.transforms = transforms
def __len__(self):
return (len(self.x))
def __getitem__(self, idx):
data = np.asarray(self.x[idx][0:]).astype(np.uint8)
if self.transforms:
data = self.transforms(data)
if self.y is not None:
return (data, self.y[i])
else:
return data
train_data = kmnistDataset(x_train, y_train, transform)
valid_data = kmnistDataset(x_valid, y_valid, transform)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle = False)
我将跳过模型结构。
培训(这里,我收到了错误信息)
n_epochs = 30
valid_loss_min = np.Inf
for epoch in range(1, n_epochs+1):
train_loss = 0
valid_loss = 0
###################
# train the model #
###################
model.train()
for data in train_loader:
inputs, labels = data[0], data[1]
optimizer.zero_grad()
output = model(inputs)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()*data.size(0)
#####################
# validate the model#
#####################
model.eval()
for data in valid_loader:
inputs, labels = data[0], data[1]
output = model(inputs)
loss = criterion(output, labels)
valid_loss += loss.item()*data.size(0)
train_loss = train_loss/ len(train_loader.dataset)
valid_loss = valid_loss / len(valid_loader.dataset)
print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
epoch, train_loss, valid_loss))
虽然我检查了数据大小,但还是收到了下面的错误消息。
<块引用>索引 2047 超出轴 0 的范围,大小为 1638
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-42-b8783819421f> in <module>
11 ###################
12 model.train()
---> 13 for data in train_loader:
14 inputs, labels = data[0], data[1]
15 optimizer.zero_grad()
/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in_next_data(self)
473 def _next_data(self):
474 index = self._next_index() # may raise StopIteration
--> 475 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
476 if self._pin_memory:
477 data = _utils.pin_memory.pin_memory(data)
/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self,
possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-38-e5c87dd8a7ff> in __getitem__(self, idx)
17
18 if self.y is not None:
---> 19 return (data, self.y[i])
20 else:
21 return data
IndexError: index 2047 is out of bounds for axis 0 with size 1638
你能解释一下为什么以及如何解决吗?
答案 0 :(得分:1)
乍一看,您使用的形状不正确:org_x = org_x.reshape(-1, 28, 28, 1)
。通道轴是第二个(与 TensorFlow 不同),如 (batch_size, channels, height, width)
:
org_x = org_x.reshape(-1, 1, 28, 28)
与 x_test
相同
x_test = x_test.reshape(-1, 1, 28, 28)
此外,您访问的列表越界。您使用 self.y
访问了 i
。在我看来,您应该返回 (data, self.y[idx])
。