我是Keras的新手,刚刚开始研究一些例子。我正在处理以下问题:我有4032个样本并且使用大约650个样本用于拟合或基本上训练状态,然后使用其余的来测试模型。问题是我一直收到以下错误:
nil
我理解为什么会收到此错误,我的问题是,如果我的数据大小不能被Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.
整除怎么办?我曾经和Deeplearning4j LSTM一起工作,没有必要处理这个问题。无论如何都可以解决这个问题吗?
由于
答案 0 :(得分:3)
最简单的解决方案是使用fit_generator而不是fit。我编写了一个简单的dataloader类,可以继承它来做更复杂的事情。它会看起来像这样,get_next_batch_data被重新定义为你的数据,包括扩充等等。
class BatchedLoader():
def __init__(self):
self.possible_indices = [0,1,2,...N] #(say N = 33)
self.cur_it = 0
self.cur_epoch = 0
def get_batch_indices(self):
batch_indices = self.possible_indices [cur_it : cur_it + batchsize]
# If len(batch_indices) < batchsize, the you've reached the end
# In that case, reset cur_it to 0 and increase cur_epoch and shuffle possible_indices if wanted
# And add remaining K = batchsize - len(batch_indices) to batch_indices
def get_next_batch_data(self):
# batch_indices = self.get_batch_indices()
# The data points corresponding to those indices will be your next batch data