在Tensorflow 2中开始培训需要很长时间

时间:2020-06-15 13:38:14

标签: python tensorflow tensorflow2.0

Tensorflow 2大约需要15分钟来制作其静态图(或在第一遍之前进行的操作)。此后的训练时间是正常的,但显然很难尝试等待15分钟的反馈。

生成器编码器和鉴别器是Keras模型中具有GRU单元的RNN(未展开)。

生成器解码器的定义和调用方式如下:

class GeneratorDecoder(tf.keras.layers.Layer):
def __init__(self, feature_dim):
    super(GeneratorDecoder, self).__init__()
    self.cell = tf.keras.layers.GRUCell(
        GRUI_DIM, activation='tanh', recurrent_activation='sigmoid',
        dropout=DROPOUT, recurrent_dropout=DROPOUT)
    self.batch_normalization = tf.keras.layers.BatchNormalization()
    self.dense = tf.keras.layers.Dense(
        feature_dim, activation='tanh')

@tf.function
def __call__(self, z, timesteps, training):
    # z has shape (batch_size, features)
    outputs = []
    output, state = z, z
    for i in range(timesteps):
        output, state = self.cell(inputs=output, states=state,
                                  training=training)
        dense_output = self.dense(
            self.batch_normalization(output))
        outputs.append(dense_output)
    return outputs

这是我的训练循环(mask_gt和missing_data变量是使用tf.cast强制转换的,因此应该已经是张量):

for it in tqdm(range(NO_ITERATIONS)):
   print(it)
   train_step()


@tf.function
def train_step():
    with tf.GradientTape(persistent=True) as tape:
        generator_output = generator(missing_data, training=True)
        imputed_data = get_imputed_data(missing_data, generator_output)
        mask_pred = discriminator(imputed_data)
        D_loss = discriminator.loss(mask_pred, mask_gt)
        G_loss = generator.loss(missing_data, mask_gt,
                                generator_output, mask_pred)
    gen_enc_grad = tape.gradient(
        G_loss, generator.encoder.trainable_variables)
    gen_dec_grad = tape.gradient(
        G_loss, generator.decoder.trainable_variables)
    disc_grad = tape.gradient(
        D_loss, discriminator.model.trainable_variables)
    del tape

    generator.optimizer.apply_gradients(
        zip(gen_enc_grad, generator.encoder.trainable_variables))
    generator.optimizer.apply_gradients(
        zip(gen_dec_grad, generator.decoder.trainable_variables))
    discriminator.optimizer.apply_gradients(
        zip(disc_grad, discriminator.model.trainable_variables))

请注意,在几秒钟内将打印出“ 0”,因此速度较慢的部分肯定不会更早。 这就是被称为的get_imputed_data函数:

def get_imputed_data(incomplete_series, generator_output):
    return tf.where(tf.math.is_nan(incomplete_series), generator_output, incomplete_series)

谢谢您的回答!希望我提供的代码足够多,可以使您了解问题所在。这是我阅读至少五年后第一次在这里发表文章:)

我使用Python 3.6和Tensorflow 2.1。

1 个答案:

答案 0 :(得分:1)

通过删除生成器和鉴别器的调用函数的tf.function装饰器解决了该问题。我在两个tf.function装饰函数中使用了单个全局python标量(迭代号)。这导致每次都创建一个新图(请参见警告in the tf.function docs)。

解决方案是删除使用的python变量或将它们转换为tensorflow变量。