Keras: what if the data size is not divisible by batch_size?

Asked: 2016-06-22 17:05:00

Tags: theano deep-learning keras

I am new to Keras and have just started working through some examples. I am dealing with the following problem: I have 4032 samples and use about 650 of them for fitting, or basically the training state, and then use the rest to test the model. The problem is that I keep getting the following error:

Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.

I understand why I get this error; my question is, what if my data size is not divisible by batch_size? I used to work with Deeplearning4j LSTMs and never had to deal with this issue. Is there any way around it?

Thanks
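As a quick illustration of the constraint the asker describes (this snippet is not from the original post, just an assumed workaround): one blunt fix is to trim the dataset to the largest multiple of the batch size before calling fit.

```python
import numpy as np

# Illustrative workaround (an assumption, not the accepted answer below):
# drop the leftover samples so the sample count divides evenly by batch_size.
batch_size = 32
x = np.arange(650 * 4, dtype=np.float32).reshape(650, 4)  # 650 samples, 4 features

usable = (len(x) // batch_size) * batch_size  # largest multiple of batch_size
x_trimmed = x[:usable]

print(len(x), usable)  # 650 samples -> 640 usable with batch_size = 32
```

The downside is that up to batch_size - 1 samples are silently discarded, which is why the accepted answer below prefers a generator instead.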

1 answer:

Answer 0 (score: 3)

The simplest solution is to use fit_generator instead of fit. I wrote a simple dataloader class that you can inherit from to do more sophisticated things. It would look something like this, with get_next_batch_data redefined for your own data, including augmentation and so on.

class BatchedLoader():
    def __init__(self, possible_indices, batch_size):
        self.possible_indices = possible_indices  # e.g. [0, 1, 2, ..., N] (say N = 33)
        self.batch_size = batch_size
        self.cur_it = 0
        self.cur_epoch = 0

    def get_batch_indices(self):
        batch_indices = self.possible_indices[self.cur_it : self.cur_it + self.batch_size]
        self.cur_it += self.batch_size
        # If len(batch_indices) < batch_size, you've reached the end of the epoch.
        # In that case, reset cur_it, increase cur_epoch (shuffle possible_indices
        # here if wanted), and add the remaining
        # K = batch_size - len(batch_indices) indices from the new epoch.
        if len(batch_indices) < self.batch_size:
            self.cur_epoch += 1
            missing = self.batch_size - len(batch_indices)
            batch_indices += self.possible_indices[:missing]
            self.cur_it = missing
        return batch_indices

    def get_next_batch_data(self):
        batch_indices = self.get_batch_indices()
        # The data points corresponding to those indices will be your
        # next batch of data; load (and augment) them here.
        return batch_indices