如何使用autograph
“魔术”提高下面代码的执行速度?
从此处提供代码:https://www.tensorflow.org/alpha/tutorials/generative/dcgan
尤其是https://www.tensorflow.org/alpha/tutorials/generative/dcgan#train_the_model这部分:
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
# Save the model every 15 epochs
if (epoch + 1) % 15 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
并考虑将tf.data.Dataset
包装到@tf.function
调用中的好处。如此处所示:Does wrapping tf.data.Dataset into tf.function improve performance?。
我尝试过的事情:
由于train()
中的所有其他python代码,无法将@tf.function
包装到train
中。但是:记录,检查点等操作通常需要此代码。
问题: