使用Tensorflow Estimator API进行混合精度训练

时间:2019-03-21 14:22:48

标签: python tensorflow tensorflow-estimator

有人使用过tensorflow估计器API进行混合精度训练吗?

我尝试将输入投射到tf.float16,然后将网络结果重新投射到tf.float32。为了缩小损失,我使用了tf.contrib.mixed_precision.LossScaleOptimizer。

我得到的错误消息相对没有信息:“试图将'x'转换为张量并失败。错误:不支持任何值”

1 个答案:

答案 0 :(得分:0)

我发现了问题:我使用tf.get_variable存储学习率。此变量没有梯度。普通的优化器不在乎,但是tf.contrib.mixed_precision.LossScaleOptimizer崩溃。因此,请确保未将这些变量添加到tf.GraphKeys.TRAINABLE_VARIABLES。