tf.signal.irfft output problem in a Keras Unet model

Time: 2019-05-16 19:00:12

Tags: python tensorflow keras autoencoder

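I have converted the Keras Unet from {net3} on GitHub to a 1D version so that it can process mono audio files. I also added tf.signal.rfft through a Lambda layer at the start and tf.signal.irfft at the end. This is based on some Googling around the topic.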

I get an error from the irfft Lambda because its output shape is unexpectedly (NUM_SAMPLES, 0). The input layer and the expected shape are both (NUM_SAMPLES, 1).

Error:

ValueError: Error when checking target: expected lambda_5 to have shape (1024, 0) but got array with shape (1024, 1)

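When I inspect the model layers, I can see that the final conv layer has the output shape I expect, (NUM_SAMPLES, 1), but the Lambda output shape is (NUM_SAMPLES, 0), which is consistent with the error. I tried adding a Reshape layer, and also tried Flatten followed by Dense, without any luck.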
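A minimal sketch of the shape arithmetic behind the (NUM_SAMPLES, 0) result, assuming the inverse transform runs over the innermost axis with no explicit FFT length. In that case tf.signal.irfft defaults to a length of 2 * (innermost size - 1). The helper names below are hypothetical and exist only for illustration; this is not a confirmed diagnosis of the model.

```python
def default_irfft_length(innermost_size):
    """Default output length of an inverse real FFT when no length is given.

    Hypothetical helper: mirrors the default 2 * (innermost_size - 1).
    """
    return 2 * (innermost_size - 1)


def default_rfft_bins(innermost_size):
    """Number of frequency bins a real FFT returns for a real input.

    Hypothetical helper: a length-n real signal yields n // 2 + 1 bins.
    """
    return innermost_size // 2 + 1


# Shape of the last conv layer as reported in the question: (time, channels).
conv10_shape = (1024, 1)

# The innermost axis is the channel axis, which has size 1.
print(default_irfft_length(conv10_shape[-1]))  # 0

# For comparison: a 1024-sample signal has 513 bins, which invert to 1024.
print(default_rfft_bins(1024))                 # 513
print(default_irfft_length(513))               # 1024
```

Under this assumption, a tensor of shape (NUM_SAMPLES, 1) has an innermost axis of size 1, which gives an output length of 0. The same reasoning suggests that the first rfft Lambda also transforms the channel axis rather than the time axis, which would leave its shape at (NUM_SAMPLES, 1) and hide the problem until the end of the model.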

Code:


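# Imports assumed from the names used below (tf.keras on TensorFlow 1.x).
import tensorflow as tf
from tensorflow.keras.layers import (Input, Lambda, Conv1D, MaxPooling1D,
                                     Dropout, UpSampling1D, concatenate)
from tensorflow.keras.optimizers import Adam


def unet(input_shape=(1024, 1), fft=True):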
    wave_input = Input(shape=input_shape, name='main_input')

    x = Lambda(lambda v: tf.to_float(tf.spectral.rfft(v)))(wave_input)

    conv1 = Conv1D(
        64,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(x)
    conv1 = Conv1D(
        64,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling1D(2)(conv1)
    conv2 = Conv1D(
        128,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(pool1)
    conv2 = Conv1D(
        128,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling1D(2)(conv2)
    conv3 = Conv1D(
        256,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(pool2)
    conv3 = Conv1D(
        256,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling1D(2)(conv3)
    conv4 = Conv1D(
        512,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(pool3)
    conv4 = Conv1D(
        512,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling1D(2)(drop4)

    conv5 = Conv1D(
        1024,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(pool4)
    conv5 = Conv1D(
        1024,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv1D(
        512,
        2,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(UpSampling1D(2)(drop5))
    merge6 = concatenate([drop4, up6], axis=2)
    conv6 = Conv1D(
        512,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(merge6)
    conv6 = Conv1D(
        512,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv6)

    up7 = Conv1D(
        256,
        2,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(UpSampling1D(2)(conv6))
    merge7 = concatenate([conv3, up7], axis=2)
    conv7 = Conv1D(
        256,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(merge7)
    conv7 = Conv1D(
        256,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv7)

    up8 = Conv1D(
        128,
        2,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(UpSampling1D(2)(conv7))
    merge8 = concatenate([conv2, up8], axis=2)
    conv8 = Conv1D(
        128,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(merge8)
    conv8 = Conv1D(
        128,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv8)

    up9 = Conv1D(
        64,
        2,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(UpSampling1D(2)(conv8))
    merge9 = concatenate([conv1, up9], axis=2)
    conv9 = Conv1D(
        64,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(merge9)
    conv9 = Conv1D(
        64,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv9)
    conv9 = Conv1D(
        2,
        3,
        activation='relu',
        padding='same',
        kernel_initializer='he_normal')(conv9)
    conv10 = Conv1D(1, 1, activation='sigmoid')(conv9)

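    # Inverse real FFT back to the time domain; this is the layer that
    # reports shape (NUM_SAMPLES, 0) in the error above.
    x = tf.keras.layers.Lambda(lambda v: tf.to_float(tf.spectral.irfft(tf.cast(v, dtype=tf.complex64))))(conv10)
    # x = tf.keras.layers.Reshape(input_shape)(x)
    # output = tf.keras.layers.Dense(input_shape[0])(x)

    # `output` is only assigned in the commented-out line above, so the
    # Lambda output `x` is passed directly.
    model = tf.keras.models.Model(inputs=[wave_input], outputs=[x])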

    model.compile(
        optimizer=Adam(lr=1e-4),
        loss='binary_crossentropy',
        metrics=['accuracy'])

    return model
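
One possible arrangement, sketched in NumPy rather than TensorFlow so that it runs on its own: apply the transform along the time axis, keep the real and imaginary parts as two channels, and pass the output length explicitly on the way back. The function names are hypothetical, and this is an illustration under those assumptions rather than a verified fix for the model above.

```python
import numpy as np

NUM_SAMPLES = 1024


def to_spectrum_channels(wave):
    """(batch, NUM_SAMPLES, 1) real wave -> (batch, bins, 2) real/imag channels."""
    # Drop the channel axis so that the time axis is innermost for the FFT.
    spectrum = np.fft.rfft(wave[..., 0], axis=-1)  # complex, (batch, 513)
    # Keep both parts; casting complex to float would discard the imaginary part.
    return np.stack([spectrum.real, spectrum.imag], axis=-1)


def from_spectrum_channels(channels, n=NUM_SAMPLES):
    """(batch, bins, 2) real/imag channels -> (batch, n, 1) real wave."""
    spectrum = channels[..., 0] + 1j * channels[..., 1]
    # Pass the length explicitly instead of relying on the default.
    wave = np.fft.irfft(spectrum, n=n, axis=-1)  # real, (batch, n)
    return wave[..., np.newaxis]


rng = np.random.default_rng(0)
wave = rng.standard_normal((4, NUM_SAMPLES, 1))

channels = to_spectrum_channels(wave)
restored = from_spectrum_channels(channels)

print(channels.shape)               # (4, 513, 2)
print(restored.shape)               # (4, 1024, 1)
print(np.allclose(wave, restored))  # True
```

Note that 513 frequency bins is not divisible by 16, so with four pooling stages the frequency axis would likely need padding or cropping before the skip connections line up.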

Thanks for your help.

0 answers:

No answers yet.