在keras训练期间使用fit_generator的问题

时间:2019-07-10 13:36:17

标签: python keras neural-network

我正在处理非常大的文本数据集。我考虑过使用model.fit_generator方法而不是简单的model.fit,所以我尝试使用此生成器:

def TrainGenerator(inp, out):
  for i,o in zip(inp, out):
    yield i,o

当我尝试使用以下方法进行训练时使用它:

#inp_train, out_train are lists of sequences padded to 50 tokens
model.fit_generator(generator = TrainGenerator(inp_train,out_train),
                    steps_per_epoch = BATCH_SIZE * 100,
                    epochs = 20,
                    use_multiprocessing = True)

我得到:

ValueError: Error when checking input: expected embedding_input to have shape (50,) but got array with shape (1,)

现在,我尝试使用简单的model.fit方法,并且工作正常。所以,我认为我的问题出在发电机上,但是由于我是发电机的新手,所以我不知道如何解决它。完整的模型摘要为:

Layer (type)                 Output Shape            
===========================================
Embedding (Embedding)      (None, 50, 400)           
___________________________________________
Bi_LSTM_1 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Bi_LSTM_2 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Output (Dense)             (None, 50, 153)           
===========================================

编辑1

第一个评论以某种方式触发了我。我意识到我误解了发电机的工作原理。我的生成器的输出是形状为50的列表,而不是形状为50的N个列表的列表。因此,我研究了keras文档并发现了this。因此,我更改了工作方式,这是该类作为生成器的工作方式:

class BatchGenerator(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return batch_x, to_categorical(batch_y,num_labels)

其中to_categorical是函数:

def to_categorical(sequences, categories):
    cat_sequences = []
    for s in sequences:
        cats = []
        for item in s:
            cats.append(np.zeros(categories))
            cats[-1][item] = 1.0
        cat_sequences.append(cats)
    return np.array(cat_sequences)

因此,我现在注意到的是网络的良好性能提升,每个时代现在持续了一半。发生这种情况是因为我没有更多的可用RAM,因为现在我没有将所有数据集都加载到内存中?

0 个答案:

没有答案