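I have a training set of 9957 images. The training set has shape (9957, 3, 60, 80). Do I need a batch size when feeding the training set to the model? If so, can the original shape be treated as the correct shape for a conv2D layer, or do I need to add the batch size to input_shape?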
X_train.shape
(9957, 60, 80, 3)
from chainer import iterators
from chainer.datasets import split_dataset_random
from chainer.dataset import DatasetMixin
import numpy as np
class MyDataset(DatasetMixin):
    def __init__(self, X, labels):
        super(MyDataset, self).__init__()
        self.X_ = X
        self.labels_ = labels
        self.size_ = X.shape[0]

    def __len__(self):
        return self.size_

    def get_example(self, i):
        # Convert one image from (height, width, channel) to (channel, height, width).
        return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i]

batch_size = 3
label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False,
                                      shuffle=False)
Answer 0 (score: 1)
The code below shows that you do not need to handle the batch size yourself. Just use DatasetMixin and SerialIterator as described in the Chainer tutorial.
from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np
NUM_IMAGES = 9957
NUM_CHANNELS = 3 # RGB
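# Chainer's convolution links expect (batch, channel, height, width).
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 80
NUM_CLASSES = 10
BATCH_SIZE = 32
TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))
# Chainer's default float dtype is float32; int32 is the conventional label dtype.
images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH).astype(np.float32)
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,)).astype(np.int32)
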
class MyDataset(DatasetMixin):
    def __init__(self, images_, labels_):
        # note: the trailing underscore on the arguments only avoids shadowing
        super(MyDataset, self).__init__()
        self.images_ = images_
        self.labels_ = labels_
        self.size_ = len(labels_)

    def __len__(self):
        return self.size_

    def get_example(self, i):
        return self.images_[i, ...], self.labels_[i]

dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)
###############################################################################
"""This block is only for confirmation.

.. note:: Calling :func:`concat_examples` directly in your own code is NOT
   recommended. Use :class:`chainer.training.updaters.StandardUpdater` instead.
"""
from chainer.dataset import concat_examples
batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))
Output
batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)
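To see where the leading batch axis comes from, here is a minimal NumPy-only sketch of a single, non-shuffled pass over a dataset. The helper `iterate_batches` is hypothetical and is not part of Chainer; it only mimics what an iterator created with `repeat=False, shuffle=False` does conceptually. The stand-in images are shrunk to 6x8 to keep the example light.

```python
import numpy as np


def iterate_batches(images, labels, batch_size):
    """Yield (image_batch, label_batch) pairs in order, without repeating.

    Hypothetical helper for illustration only; this is not Chainer code.
    """
    for start in range(0, len(labels), batch_size):
        stop = start + batch_size
        # Slicing past the end is safe: the last batch is simply shorter.
        yield images[start:stop], labels[start:stop]


# Stand-in for the validation split: 9957 - 8000 = 1957 examples.
valid_images = np.zeros((1957, 3, 6, 8), dtype=np.float32)
valid_labels = np.zeros((1957,), dtype=np.int32)

batch_shapes = [x.shape for x, _ in iterate_batches(valid_images, valid_labels, 32)]
print(len(batch_shapes))  # 62 batches in one pass
print(batch_shapes[0])    # (32, 3, 6, 8): a full batch
print(batch_shapes[-1])   # (5, 3, 6, 8): the final, partial batch
```

The per-example shape never contains the batch size. Batching only adds a leading axis, and in a non-repeating pass the last batch may be smaller than the requested size.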
Note that chainer.dataset.concat_examples is a bit tricky. Normally, if you use StandardUpdater, which calls chainer.dataset.concat_examples internally, you never have to deal with this function directly.
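For examples that all have the same shape, the conversion from a list of (image, label) pairs to batch arrays amounts to stacking along a new first axis. The sketch below uses a hypothetical helper, `stack_examples`, written with plain NumPy. It is a simplified stand-in and not Chainer's implementation; the real function also handles details such as padding and device transfer.

```python
import numpy as np


def stack_examples(batch):
    """Stack a list of (image, label) pairs into two arrays.

    Hypothetical, simplified stand-in; assumes every image has the same shape.
    """
    images = np.stack([image for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return images, labels


# A batch as an iterator would hand it over: a list of 32 (image, label) pairs.
batch = [
    (np.zeros((3, 60, 80), dtype=np.float32), np.int32(i % 10))
    for i in range(32)
]

batch_image, batch_label = stack_examples(batch)
print(batch_image.shape)  # (32, 3, 60, 80)
print(batch_label.shape)  # (32,)
```

These are the same shapes as in the confirmation output above, which is why the dataset only needs to return single examples of shape (3, 60, 80).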
Chainer is designed around the scheme of Trainer, (Standard)Updater, an Optimizer, (Serial)Iterator, and Dataset(Mixin). If you do not follow this scheme, you will have to dig into the Chainer source code.