Keras使用Sequential图层添加数据

时间:2017-03-11 12:50:22

标签: python neural-network deep-learning keras recurrent-neural-network

我在Keras有两个连续模型:

def generator_model():
    model = Sequential()
    model.add(LSTM(512, return_sequences=False, input_shape=(3, 100*100)))
    model.add(Dense(100*100))
    model.add(Reshape((1, 100, 100), input_shape=(100*100,)))
    model.add(Activation('sigmoid'))
    model.compile(optimizer='adadelta', loss='categorical_crossentropy');
    return model

def discriminator_model():
    model = Sequential()
    model.add(LSTM(512, return_sequences=False, input_shape=(4, 100*100)))
    model.add(Dense(100*100))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    model.compile(optimizer='adadelta', loss='categorical_crossentropy');
    return model

我还有一个将这些模型连接在一起的功能。我正在尝试用生成器和鉴别器训练生成对抗网络作为lstm。这就是我需要函数

的原因
def generator_containing_discriminator(generator, discriminator):
    model = Sequential()
    model.add(generator)
    discriminator.trainable = False
    model.add(discriminator)
    return model

我用它来训练网络

g_loss = discriminator_on_generator.train_on_batch(noise, [1] * BATCH_SIZE)
discriminator.trainable = True
print("batch %d g_loss : %f" % (index, g_loss))

为了使最后一段训练代码能够工作,生成器和鉴别器应该能够合并在一起。但是,我的发生器的输出不能输入我的鉴别器,因为我需要在将它发送到鉴别器之前将一些数据添加到发生器的输出。我怎样才能在Keras中做到这一点,以便将鉴别器添加到发生器上?有没有办法在model.add函数中添加数据?我在Keras文档中找不到任何内容

1 个答案:

答案 0 :(得分:1)

试试这个:

def generator_containing_discriminator(generator, discriminator):
    model = Sequential()
    list_of_dicriminator_inputs = [generator]

    for _ in range(3):
        auxiliary_model = Sequential()
        auxiliary_model.add(Reshape((1, 100*100), input_shape=(100 * 100,)))
        list_of_dicriminator_inputs.append(auxiliary_model)

    extended_generator_output = Merge(list_of_dicriminator_inputs,
                                      mode="concat",
                                      concat_axis=1)
    model.add(extended_generator_output)
    discriminator.trainable = False
    model.add(discriminator)
    return model

为了实现这一点,您应该更改generator代码的这一行:

model.add(Reshape((1, 100 * 100), input_shape=(100*100,)))