keras model.predict_generator()没有返回正确数量的实例

时间:2018-09-07 09:49:19

标签: python tensorflow keras generator

我已按照以下链接学习了将generator模型的keras用于fit_generator的方法。 https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly 我遇到的一个问题是,当我在某个测试数据生成器上调用model.predict_generator()时,返回值的长度与在生成器中发送的值不同。 我的测试数据长度为229431,并且我使用的batch_size为256,并且当我通过以下方式在__len__类中定义generator函数时:

class DataGenerator(keras.utils.Sequence):
    """A simple generator"""

    def __init__(self, list_IDs, labels, dim, dim_label, batch_size=512, shuffle=True, is_training=True):
        """Initialization"""
        self.list_IDs = list_IDs
        self.labels = labels
        self.dim = dim
        self.dim_label = dim_label
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.is_training = is_training
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data"""
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size: (index + 1) * self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        list_labels_temp = [self.labels[k] for k in indexes]

        # Generate data
        result = self.__data_generation(list_IDs_temp, list_labels_temp, self.is_training)
        if self.is_training:
            X, y = result
            return X, y
        else:
            # only return X when test
            X = result
            return X

    def on_epoch_end(self):
        """Updates indexes after each epoch"""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp, list_labels_temp, is_training):
        """Generates data containing batch_size samples"""
        # Initialization
        # X is a list of np.array
        X = np.empty((self.batch_size, *self.dim))
        if is_training:
            # y could have multiple columns
            y = np.empty((self.batch_size, *self.dim_label), dtype=int)

        # Generate data
        for i, (ID, label) in enumerate(zip(list_IDs_temp, list_labels_temp)):
            # Store sample
            X[i,] = np.load(ID)
            if is_training:
                # Store class
                y[i,] = np.load(label)
        if is_training:
            return X, y
        else:
            return X

我的预测值的返回长度为229632。这是predict的代码:

test_generator = DataGenerator(partition, labels, is_training=False, **self.params)
        predict_raw = self.model.predict_generator(generator=test_generator, workers=12, verbose=2)

我认为229632/256 = 897,这是我生成器的长度,当我将__len__的{​​{1}}方法修改为DataGenerator时,我得到229376预测值,229376 / 256 = 896,这是正确的长度数。 但是我传递给生成器的是229431个样本。

我认为在return int(np.ceil(len(self.list_IDs) / self.batch_size))方法中,当在最后一批上运行时,它应仅获取少于256个样本以进行自动测试。但是显然不是这样,那么如何确定模型能够预测正确数量的样本?

1 个答案:

答案 0 :(得分:1)

对于最后一批,在方法foreach (InfoQuery item in InfoList) { if (item == "kitten") { if (!done) { TextView view= new TextView(ApplicationContext); view.LayoutParameters = mainLayout.LayoutParameters; view.TextSize = TypedValue.ApplyDimension(ComplexUnitType.Sp, 3, ApplicationContext.Resources.DisplayMetrics); view.SetPadding((int)TypedValue.ApplyDimension(ComplexUnitType.Dip, 10, ApplicationContext.Resources.DisplayMetrics), 0, (int)TypedValue.ApplyDimension(ComplexUnitType.Dip, 10, ApplicationContext.Resources.DisplayMetrics), 0); view.Text = item.position; layout2.AddView(view,0); done = true; } TextView view2= new TextView(ApplicationContext); view2.LayoutParameters = mainLayout.LayoutParameters; view2.TextSize = TypedValue.ApplyDimension(ComplexUnitType.Sp, 3, ApplicationContext.Resources.DisplayMetrics); ; view2.SetPadding((int)TypedValue.ApplyDimension(ComplexUnitType.Dip, 30, ApplicationContext.Resources.DisplayMetrics), 0, (int)TypedValue.ApplyDimension(ComplexUnitType.Dip, 10, ApplicationContext.Resources.DisplayMetrics), 0); view2.SetTextColor(Color.Black); view2.Text = item.position; layout2.AddView(view2); } } 中计算出的索引大小不正确。为了预测正确的样本数量,索引应定义如下(请参见post):

__getitem__