TensorFlow 2.0 GradientTape does not work with a Keras model

Asked: 2019-04-27 19:01:57

Tags: python-3.x tensorflow tensorflow2.0

I am building an autoencoder with TensorFlow 2.0 and training it on the MNIST dataset. The code can be summarized as:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Model

def build_encoder(inp):
  x = Conv2D(16, (3, 3), activation='relu', padding='same')(inp)
  x = MaxPooling2D((2, 2), padding='same')(x)
  x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
  x = MaxPooling2D((2, 2), padding='same')(x)
  x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
  encoded = MaxPooling2D((2, 2), padding='same')(x)

  return encoded

def build_decoder(z):
  x = Conv2D(8, (3, 3), activation='relu', padding='same')(z)
  x = UpSampling2D((2, 2))(x)
  x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
  x = UpSampling2D((2, 2))(x)
  x = Conv2D(16, (3, 3), activation='relu')(x)
  x = UpSampling2D((2, 2))(x)
  decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
  return decoded    

inp = Input(shape=((28,28,1)))
encoder = build_encoder(inp)
decoder = build_decoder(encoder)
model = Model(inputs=inp, outputs=decoder)
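As a quick check of the layer arithmetic above (a sketch, assuming the standard Keras/TensorFlow shape rules): `'same'` max-pooling rounds odd sizes up, which is why the encoder compresses 28×28 down to a 4×4×8 code, and the one decoder `Conv2D` without `padding='same'` shrinks 16 to 14 so the final upsample lands back on 28:

```python
import math

size = 28
for _ in range(3):              # three MaxPooling2D layers with padding='same'
    size = math.ceil(size / 2)  # 28 -> 14 -> 7 -> 4 ('same' rounds 7/2 up to 4)
assert size == 4                # encoded feature map is 4x4x8

size *= 2                       # UpSampling2D: 4 -> 8
size *= 2                       # UpSampling2D: 8 -> 16
size -= 2                       # 'valid' 3x3 Conv2D: 16 -> 14
size *= 2                       # UpSampling2D: 14 -> 28
print(size)                     # 28: the decoder output matches the 28x28 input
```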

Here is the training loop:

for epoch in range(1, epochs + 1):
  for step, x in enumerate(training_dataset):

      with tf.GradientTape() as tape:
          # Forward pass
          x_reconstruction_logits = model(x)
          reconstruction_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_reconstruction_logits)
          reconstruction_loss = tf.reduce_sum(reconstruction_loss) / batch_size

      gradients = tape.gradient(reconstruction_loss, model.trainable_variables) 
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))

      if (step + 1) % 50 == 0:
          print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}"
                .format(epoch + 1, num_epochs, step + 1, num_batches, float(reconstruction_loss)))
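For comparison, a stripped-down GradientTape step on a tiny stand-in model (a sketch, not the autoencoder above) shows what healthy gradient flow looks like: the forward pass and the loss are both computed inside the `with` block, on a float input batch, and `tape.gradient` then returns a gradient for every trainable variable:

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Tiny stand-in model (hypothetical; not the autoencoder from the question).
inp = Input(shape=(3,))
out = Dense(1)(Dense(4, activation='relu')(inp))
model = Model(inputs=inp, outputs=out)

x = tf.random.normal([8, 3])  # float32 input batch
y = tf.zeros([8, 1])

with tf.GradientTape() as tape:
    # Forward pass and loss both happen inside the tape context.
    loss = tf.reduce_mean(tf.square(model(x) - y))

grads = tape.gradient(loss, model.trainable_variables)
print(all(g is not None for g in grads))  # True: every weight got a gradient
```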

Output:

Epoch[2/55], Step [50/600], Reconst Loss: 695.5870
Epoch[2/55], Step [100/600], Reconst Loss: 695.5870
Epoch[2/55], Step [150/600], Reconst Loss: 695.5870
Epoch[2/55], Step [200/600], Reconst Loss: 695.5870
Epoch[2/55], Step [250/600], Reconst Loss: 695.5870
Epoch[2/55], Step [300/600], Reconst Loss: 695.5870
Epoch[2/55], Step [350/600], Reconst Loss: 695.5870
Epoch[2/55], Step [400/600], Reconst Loss: 695.5870

The problem: the loss stays constant, and if I print the gradients they are all zero. What am I doing wrong here?
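One API detail worth keeping in mind when debugging a setup like this (an observation about the TensorFlow API, not a confirmed diagnosis of the code above): `tf.nn.sigmoid_cross_entropy_with_logits` expects raw, pre-activation scores, so feeding it the output of a `sigmoid`-activated layer applies the sigmoid twice and silently computes a different loss:

```python
import tensorflow as tf

logits = tf.constant([[2.0, -1.0]])
labels = tf.constant([[1.0, 0.0]])

# Intended usage: raw (pre-sigmoid) scores.
loss_from_logits = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=labels, logits=logits)

# Passing already-activated outputs applies sigmoid twice.
loss_double_sigmoid = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=labels, logits=tf.sigmoid(logits))

print(float(tf.reduce_sum(loss_from_logits)))
print(float(tf.reduce_sum(loss_double_sigmoid)))  # a different value
```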

0 Answers