tensorflow 0.10.0rc版本支持float16吗?

时间:2016-08-04 08:21:37

标签: tensorflow

为了减少张量,我在模型中用dytpe=tf.float16定义了所有变量,然后定义了优化器:

optimizer = tf.train.AdamOptimizer(self.learning_rate)
self.compute_gradients = optimizer.compute_gradients(self.mean_loss_reg)
train_adam_op = optimizer.apply_gradients(self.compute_gradients, global_step=self.global_step)

一切正常!但是在我运行train_adam_op之后,渐变和变量在python中是nan。我喜欢如果apply_gradients() API支持tf.float16类型?为什么apply_gradients() session.run()调用{{1}}之后我得到了南。

1 个答案:

答案 0 :(得分:4)

与32位浮点数相比,fp16的动态范围相当有限。因此,它们很容易溢出或下溢,这通常会导致您遇到的NaN。

您可以在模型中插入一些check_numerics操作,以帮助查明在fp16上执行时变得不稳定的特定操作。

例如,您可以按如下方式包装L2丢失操作,以检查其结果是否适合fp16

A = tf.l2_loss(some_tensor)

变为

A = tf.check_numerics(tf.l2_loss(some_tensor), "found the root cause")

最常见的溢出和下溢源是exp(),log(),以及各种分类原语,所以我会开始寻找它。

一旦你弄清楚哪个操作序列有问题,你可以使用tf.cast()将序列的输入转换为32位浮点数,然后使用32位浮点数来更新模型以执行该序列结果回到fp16。