如何创建一个批处理生成器,以训练具有不同长度且批处理大小> 1的序列的Keras模型?

时间:2020-09-03 17:43:52

标签: python keras nlp lstm

在我的训练数据中,我均匀分布了长度为​​1,2,3的文本序列。 我的目标数据是一个代表一个单词的热编码矢量。

示例数据

 X = [[1,4,3],
      [1,4],
      [1],...]

 y = [0 0 0 0 1 0 0,
      0 0 1 0 0 0 0,
      .............]

基于该线程https://datascience.stackexchange.com/questions/48796/how-to-feed-lstm-with-different-input-array-sizes的答案,我看到有两种方法可以训练长度可变的模型,而无需使用填充。

一种方法是训练批次大小为1的模型,类似于上述帖子中的解决方案,我也将在此处发布。

    class MyBatchGenerator(Sequence):
        'Generates data for Keras'
        def __init__(self, X, y, batch_size=1, shuffle=True):
            'Initialization'
            self.X = X
            self.y = y
            self.batch_size = batch_size
            self.shuffle = shuffle
            self.on_epoch_end()

        def __len__(self):
            'Denotes the number of batches per epoch'
            return int(np.floor(len(self.y)/self.batch_size))

        def __getitem__(self, index):
            return self.__data_generation(index)

        def on_epoch_end(self):
            'Shuffles indexes after each epoch'
            self.indexes = np.arange(len(self.y))
            if self.shuffle == True:
                np.random.shuffle(self.indexes)

        def __data_generation(self, index):
            Xb = np.empty((self.batch_size, *X[index].shape))
            yb = np.empty((self.batch_size, *y[index].shape))
            # naively use the same sample over and over again
            for s in range(0, self.batch_size):
                Xb[s] = X[index]
                yb[s] = y[index]
            return Xb, yb

model = Sequential()
model.add(Embedding(vocabulary_size, 64 , input_length=None))
model.add(LSTM(50,return_sequences=True))
model.add(LSTM(50))
model.add(Dense(50,activation='relu'))
model.add(Dense(vocabulary_size, activation='softmax'))
# compiling the network
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit_generator(MyBatchGenerator(X_train, y_train, batch_size=1), epochs=100)

如何修改MyBatchGenerator使其返回batch_size长度与batch_size > 1相同的序列?

0 个答案:

没有答案