在keras中处理数据集时进行批预处理

时间:2019-06-12 20:19:10

标签: keras padding

我有一些可变长度数据矩阵及其相关标签的示例,我想用它来训练LSTM网络。我知道我至少应该为每个批次填充数据样本(例如,使用keras.preprocessing.sequence.pad_sequences),并且我成功地用numpy数组填充了网络数据,但是我不知道如何使用TFRecord数据集。

我的TFRecord文件有一个典型的读取代码,如下所示:

featuresDict = {'data': tf.FixedLenSequenceFeature([], dtype=tf.string),
                'dataShape': tf.FixedLenSequenceFeature([], dtype=tf.int64),
                'label': tf.FixedLenSequenceFeature([], dtype=tf.int64)
               }

def parse_tfrecord(example):
    context, features = tf.parse_single_sequence_example(example, sequence_features=featuresDict)   
    label = features['label']
    data_shape = features['dataShape']
    data = tf.decode_raw(features['data'], tf.int64)
    data = tf.reshape(data, data_shape)
    return label, data

def DataGenerator(fileName, numEpochs=None, batchSize=None):    
  dataset = tf.data.TFRecordDataset(fileName, compression_type='GZIP')
  dataset = dataset.map(parse_tfrecord)
  dataset = dataset.batch(batchSize)
  dataset = dataset.repeat(numEpochs)
  return dataset

我可以解析每个示例并生成我的原始数据矩阵和标签。然后,DataGenerator函数定义数据集并设置其批处理和重复功能。然后创建一个DataGenerator对象,并使用它来适合我的模型:

train_data = DataGenerator(fileName='train.gz', numEpochs=epochs, batchSize=batch_size)
model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, ...)

我可以在哪里在代码中添加填充功能?通常,如果我想使用数据集API进行批处理级预处理,该怎么做?

1 个答案:

答案 0 :(得分:0)

一种方法是在写入TFRecords时在预处理期间填充序列。然后,您可以使用与上面相同的代码。

但是我建议 padded_batch ,它与Keras序列预处理类似。 如果尺寸是已知的(papped_shapes是某个常数),则将序列填充到该常数。否则,它们被填充到最长的序列。