尺寸必须相等,但对于'model_53 / sequential_48 / dense_142 / MatMul'(op:'MatMul'),输入形状为[2,1],[2,75]的尺寸必须为1和2

时间:2018-11-27 04:20:25

标签: python tensorflow keras

    data = [K.stack([K.mean(data),K.std(data)])]

    data = K.transpose(data)

我想计算均值和标准差,然后将它们传递给鉴别器,但这不起作用。

GAN()类:     def init (自己):         self.original_dim = 1000         self.latent_dim = 1000

    optimizer = Adam(0.000000002, 0.5)

    # Build and compile the discriminator
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'])

    # Build the generator
    self.generator = self.build_generator()

    # The generator takes noise as input and generates imgs
    z = Input(shape=(self.latent_dim,))
    data = self.generator(z)
    data = [K.stack([K.mean(data),K.std(data)])]

    data = K.transpose(data)

    # For the combined model we will only train the generator
    self.discriminator.trainable = False

    # The discriminator takes generated images as input and determines validity
    validity = self.discriminator(data)

    # The combined model  (stacked generator and discriminator)
    # Trains the generator to fool the discriminator
    self.combined = Model(z, validity)
    self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


def build_generator(self):

    model = Sequential()
    model.add(Dense(150, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(150))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(self.original_dim))
    model.add(LeakyReLU(alpha=0.01))

    model.summary()

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)

def build_discriminator(self):

    model = Sequential()
    model.add(Dense(75, input_dim=2))
    model.add(Dense(75))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()

    data = Input(shape=(2,))
    validity = model(data)

    return Model(data, validity)

0 个答案:

没有答案