如何强制Tensorflow在float16下运行?

时间:2019-06-16 03:29:20

标签: json tensorflow keras

我通过定义由keras的tf后端和一些tf的张量运算符本身编写的新类,通过Keras使用自定义激活函数来构建顺序模型。我将自定义激活功能放在../keras/advanced_activation.py中。

我打算使用float16精度运行它。没有自定义功能,我可以使用以下内容轻松在float32和float16之间进行选择:

        if self.precision == 'float16':
            K.set_floatx('float16')
            K.set_epsilon(1e-4)
        else:
            K.set_floatx('float32')
            K.set_epsilon(1e-7)

但是,当将自定义函数包含到我的模型中时,即使我选择了float16,tf仍会保留在float32中。我知道默认情况下tf是在flat32下运行的,所以我的问题是:

  1. 同一文件中还有几个内置的激活函数,Keras如何使它们在float16下运行,以便我可以执行相同的操作?有一个tf方法tf.dtypes.cast(...),我可以在自定义函数中使用它来强制tf吗?这些内置函数中没有这样的转换。

  2. 或者,如何通过使用tf作为后端的Keras强制tf直接在float16下运行?

非常感谢。

1 个答案:

答案 0 :(得分:0)

我通过调试得到了答案。教训是

首先,tf.dtypes.cast(...)有效。

第二,我可以在自定义激活函数中指定第二个参数,以指示cast(...)的数据类型。以下是相关代码

第三,我们不需要tf.constant来指示这些常量的数据类型

第四,我得出的结论是,在custom_activation.py中添加自定义函数是定义我们自己的层/激活的最简单方法,只要它在各处都是可区分的,或者至少是逐段可区分的,并且在接合处不间断。

# Quadruple Piece-Wise Constant Function
class MyFunc(Layer):


    def __init__(self, sharp=100, DataType = 'float32', **kwargs):
        super(MyFunc, self).__init__(**kwargs)
        self.supports_masking = True
        self.sharp = K.cast_to_floatx(sharp)
        self.DataType = DataType

    def call(self, inputs):

        inputss = tf.dtypes.cast(inputs, dtype=self.DataType)
        orig = inputss
        # some calculations
        return # my_results

    def get_config(self):
        config = {'sharp': float(self.sharp), 
                  'DataType': self.DataType}
        base_config = super(MyFunc, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

感谢@ y.selivonchyk与我进行有价值的讨论,感谢@Yolo Swaggins的贡献。