在大型文件很多的库中,您如何在Keras模型中的“ mixed_precision”策略中放哪里?

时间:2020-08-03 22:45:20

标签: tensorflow keras

使用Tensorflow 2.3,我有一个keras模型,其函数定义跨越十几个文件。有一个主文件可以运行整个过程并进行拟合,但是当然每个文件都有其自己的import语句。该模型在build.py中构建,在compile.py中编译,然后从master.py运行。我知道,如果我想以混合精度进行训练,则需要在编译之前执行以下操作(或一些较长/较短的变化):

from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)

我的问题是,是否只需要在master.py文件,定义模型的文件(build.py,已编译(compiled.py)或所有文件中声明它?与定义模型有关吗?

1 个答案:

答案 0 :(得分:1)

您将要在构建模型的文件内设置混合精度策略,在本例中为build.py。原因是默认情况下将从全局策略中推断层的数据类型。请注意,这是在构建图层时发生的。如果您的build.py包含一个返回模型的函数,那么您可以像以前一样进行操作,或者在调用该函数以构建模型的文件中设置全局策略(master.py)。