如何将tf.keras与bfloat16结合使用

时间:2019-05-13 16:44:18

标签: python tensorflow keras google-compute-engine google-cloud-tpu

我正在尝试使用混合精度在tpu上运行tf.keras模型。我想知道如何使用bfloat16混合精度构建keras模型。是这样吗?

with tf.contrib.tpu.bfloat16_scope():
    inputs = tf.keras.layers.Input(shape=(2,), dtype=tf.bfloat16)
    logits = tf.keras.layers.Dense(2)(inputs)

logits = tf.cast(logits, tf.float32)
model = tf.keras.models.Model(inputs=inputs, outputs=logits)
model.compile(optimizer=tf.keras.optimizers.Adam(.001),
              loss='mean_absolute_error', metrics=[])

tpu_model = tf.contrib.tpu.keras_to_tpu_model(
        model,
        strategy=tf.contrib.tpu.TPUDistributionStrategy(
            tf.contrib.cluster_resolver.TPUClusterResolver(tpu='my_tpu_name')
        )
    )

1 个答案:

答案 0 :(得分:1)

您可以使用以下代码使用Keras bfloat16Mixed Precision计算和float16变量)来构建float32模型。

tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')

model = tf.keras.Sequential([
    tf.keras.layers.Inputs(input_shape=(2, ), dtype=tf.float16),    
    tf.keras.layers.Lambda(lambda x: tf.cast(x, 'float32')),
    tf.keras.layers.Dense(10)])

model.compile(optimizer=tf.keras.optimizers.Adam(.001),
              loss='mean_absolute_error', metrics=[])

model.fit(.............)

一旦构建并训练了模型,我们可以使用以下步骤保存模型:

tf.keras.experimental.export_saved_model(model, path_to_save_model)

我们可以使用以下代码加载保存的混合精度Keras模型:

new_model = tf.keras.experimental.load_from_saved_model(path_to_save_model)
new_model.summary()

如果您认为此答案有用,请接受此答案和/或对其进行投票。谢谢。