是否有类似Dataset.batch()的函数,但是对于所有具有不同大小的张量?

时间:2019-08-02 15:40:01

标签: python tensorflow tensorflow-datasets batching

我正在使用tf.data.Dataset.from_generator()函数为具有音频wav_file,音频wav_file的长度,脚本和transcript_len的ASR创建数据集。对于ML模型,我需要音频wav_file和length进行零填充,因此我已经使用.padded_batch()了。现在,除了.batch()之外,我还需要其他东西,因为这需要张量具有相同的形状,但没有零填充以批处理我的数据集。

我想使用CTC损失函数tf.nn.ctc_loss_v2,该函数不需要将transcript和transcript_len张量填充零,而是进行批处理。是否可以批处理具有不同形状的张量的数据集?


def generate_values():
    for _, row in df.iterrows():
       yield row.wav_filename, row.transcript, len(row.transcript) 

def entry_to_features(wav_filename, transcript, transcript_len):
    features, features_len = audiofile_to_features(wav_filename)
    return features, features_len, transcript, transcript_len

def batch_fn(features, features_len, transcripts, transcript_len):        
    features = tf.data.Dataset.zip((features, features_len))
    features = features.padded_batch(batch_size,
                         padded_shapes=([None, Config.n_input], []))
    trans=tf.data.Dataset.zip((transcripts, 
                     transcript_len)).batch(batch_size) ###PROBLEM: 
                     #### ONLY WORKING WITH BATCH_SIZE=1
    return tf.data.Dataset.zip((features, trans)) 

dataset = tf.data.Dataset.from_generator(generate_values,
                         output_types=(tf.string,tf.int64, tf.int64))
dataset= dataset.map(entry_to_features)
dataset= dataset.window(batch_size, drop_remainder=True)
dataset= dataset.flat_map(batch_fn)

InvalidArgumentError(请参见上面的回溯):不能在组件0中批量处理具有不同形状的张量。第一个元素的形状为[36],元素2的形状为[34]

1 个答案:

答案 0 :(得分:1)

如果您想训练seq2seq模型并使用features, transcript作为训练示例,dataset.window不是您要使用的模型。

dataset = tf.data.Dataset.from_generator(generate_values,
                         output_types=(tf.string, tf.int64, tf.int64))
dataset = dataset.map(entry_to_features)
dataset = dataset.padded_batch(batch_size, padded_shapes=([None, Config.n_input], [], [None], []))

之后,您可以按以下方式使用数据集:

for features, feature_length, labels, label_length in dataset.take(30): 
    logits, logit_length = model(features, feature_length)
    loss = tf.nn.ctc_loss_v2(labels, tf.cast(logits, tf.float32), 
                             label_length, logit_length, logits_time_major=False)