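Chainer CNN: batch size in the input shape

Time: 2019-01-07 07:19:17

Tags: python-3.x conv-neural-network chainer

I have a training set of 9957 images. The training set has shape (9957, 3, 60, 80). Is a batch size required when feeding the training set to the model? If so, is the original shape already the right shape for a Conv2D layer, or do I need to add the batch size to input_shape?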

X_train.shape

(9957, 60, 80, 3)

from chainer import iterators
from chainer.datasets import split_dataset_random
from chainer.dataset import DatasetMixin

import numpy as np


class MyDataset(DatasetMixin):
   def __init__(self, X, labels):
       super(MyDataset, self).__init__()
       self.X_ = X
       self.labels_ = labels
       self.size_ = X.shape[0]

   def __len__(self):
       return self.size_

   def get_example(self, i):
       return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i] 


batch_size = 3

label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False,
                                      shuffle=False)

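1 answer:

Answer 0 (score: 1)

The code below shows that you do not have to handle the batch size yourself. Just use DatasetMixin and SerialIterator as described in the Chainer tutorial.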

from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np

NUM_IMAGES = 9957
NUM_CHANNELS = 3  # RGB
IMAGE_HEIGHT = 60
IMAGE_WIDTH = 80

NUM_CLASSES = 10

BATCH_SIZE = 32

TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))

# Chainer's 2D convolutions expect channel-first input: (N, C, H, W).
images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,))


class MyDataset(DatasetMixin):
    def __init__(self, images_, labels_):
        # note: the trailing underscore on the argument names only avoids
        # shadowing the module-level `images` and `labels`
        super(MyDataset, self).__init__()
        self.images_ = images_
        self.labels_ = labels_
        self.size_ = len(labels_)

    def __len__(self):
        return self.size_

    def get_example(self, i):
        return self.images_[i, ...], self.labels_[i]


dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)

###############################################################################
"""This block is only for confirmation.

.. note:: Calling :func:`concat_examples` directly in your own code is NOT
    recommended. Use :class:`chainer.training.updaters.StandardUpdater` instead.
"""
from chainer.dataset import concat_examples

batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))

Output

batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)
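
To illustrate why the batch axis appears without being declared anywhere, here is a minimal NumPy-only sketch of sequential mini-batching. It does not use Chainer; `iterate_minibatches` is a hypothetical helper, and the dataset is shrunk to 10 images (an assumption, to keep it fast). Each example has shape (3, 60, 80), and taking several of them at once simply adds a leading axis.

```python
import numpy as np

# Small stand-in for the real data: 10 images instead of 9957 (assumption).
num_images, batch_size = 10, 4
images = np.zeros((num_images, 3, 60, 80), dtype=np.float32)
labels = np.arange(num_images, dtype=np.int32)


def iterate_minibatches(images, labels, batch_size):
    """Yield (image_batch, label_batch) pairs in order, without shuffling.

    Hypothetical helper for illustration; not part of Chainer.
    """
    for start in range(0, len(labels), batch_size):
        stop = start + batch_size
        # Slicing keeps the leading axis, which becomes the batch axis.
        yield images[start:stop], labels[start:stop]


batch_shapes = [x.shape for x, _ in iterate_minibatches(images, labels, batch_size)]
for shape in batch_shapes:
    print(shape)
```

The last batch holds only the 2 leftover examples, so the batch axis is not a fixed part of the model's input shape; only the (3, 60, 80) part has to match.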

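Note that chainer.dataset.concat_examples is a little tricky. Usually users do not need to care about this function, because StandardUpdater calls chainer.dataset.concat_examples internally and hides it from you.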
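
As a rough picture of the conversion involved, the following NumPy-only sketch turns a list of (image, label) tuples, which is what an iterator hands out, into one array per field. `stack_examples` is a hypothetical, simplified stand-in and not Chainer's implementation; the real concat_examples also handles things such as dict examples, padding, and device transfer.

```python
import numpy as np


def stack_examples(batch):
    """Turn [(image, label), ...] into (image_array, label_array).

    Hypothetical, simplified stand-in for illustration only.
    """
    # zip(*batch) regroups the tuples field by field:
    # all images together, then all labels together.
    columns = zip(*batch)
    return tuple(np.stack(column) for column in columns)


# A fake mini-batch of 32 examples shaped like the ones in this thread.
batch = [
    (np.zeros((3, 60, 80), dtype=np.float32), np.int32(i % 10))
    for i in range(32)
]

batch_image, batch_label = stack_examples(batch)
print(batch_image.shape)
print(batch_label.shape)
```

The two printed shapes have the same form as the confirmation output above: a leading batch axis of 32 followed by the per-example shape.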

Chainer is designed around the scheme of Trainer, (Standard)Updater, some Optimizer, (Serial)Iterator and Dataset(Mixin). If you do not follow this scheme, you will have to dig into the Chainer source code.