收敛到空白索引

时间:2020-02-03 12:08:30

标签: python tensorflow keras

我根据CTC损失为Speech2text任务编写了代码。如您所知,我们必须为语音的无声部分定义一个空白索引。我的模型很简单,如下所示:

def get_model(input_dim, output_dim,
                    rnn_units=30) -> Model:
    with tf.device('/cpu:0'):
        input_tensor = layers.Input([None, input_dim])

        x = layers.Lambda(k.expand_dims,
                          arguments=dict(axis=-1))(input_tensor)
        x = layers.Conv2D(filters=32,
                          kernel_size=[11, 4],
                          strides=[2, 2],
                          padding='same',
                          use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        x = layers.Conv2D(filters=32,
                          kernel_size=[11, 2],
                          strides=[1, 2],
                          padding='same',
                          use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Reshape([-1, input_dim // 4 * 32])(x)

        recurrent = layers.LSTM(units=rnn_units,
                               activation='tanh',
                               recurrent_activation='sigmoid',
                               use_bias=True,
                               return_sequences=True,
                               )
        x = layers.Bidirectional(recurrent,
                                 merge_mode='concat')(x)

        x = layers.TimeDistributed(layers.Dense(units=rnn_units * 2))(x)
        x = layers.ReLU()(x)
        output_tensor = layers.Dense(units=output_dim)(x)
        #        output_tensor = layers.Lambda(lambda y: softmax(y, axis=-1))(output_tensor)

        model = Model(input_tensor, output_tensor)
        return model

然后我按如下方式使用CTC损失功能

def get_loss() -> Callable:
    def get_length(tensor):
        lengths = tf.math.reduce_sum(tf.ones_like(tensor), 1)
        return tf.cast(lengths, tf.int32)

    def ctc_loss(labels, logits):
        label_length = get_length(labels)
        logit_length = get_length(tf.math.reduce_max(logits, 2))
        return tf.reduce_mean(tf.nn.ctc_loss(labels, logits, label_length, logit_length,
               logits_time_major=False, blank_index=-1))

    return ctc_loss

然后按以下方式优化损失函数:

y = Input(name='y', shape=[None], dtype='int32')
model.compile(RMSprop(1e-4), loss=get_loss(), target_tensors=[y])
model.fit(dataset, validation_data=dev_dataset, callbacks=[checkpointer], **kwargs)

经过一段时间(无论学习率如何),我得到以下结果:

[33 33 33 33 33 33 33 33 33 33 33 33 33 33 33]

数字33等于空白索引。每当我训练网络时,任何输入的输出都是!问题是什么? :-\

0 个答案:

没有答案