tensorflow 中的自定义激活函数,具有 tanh 的可学习参数

时间:2021-06-06 20:28:50

标签: tensorflow tensorflow2.0 activation-function

我想在 tensorflow 中实现一个自定义的激活函数。这个激活函数的想法是它应该学习它的线性程度。使用以下函数。

tanh(x*w)/w  for w!= 0
x            for w = 0 

应该学习参数w。但是我不知道如何在 tensorflow 中实现这一点。

1 个答案:

答案 0 :(得分:0)

激活函数只是模型的一部分,所以这里是您描述的函数的代码。

import tensorflow as tf
from tensorflow.keras import Model

class MyModel(Model):
    def __init__(self):
        super().__init__()
        # Some layers
        self.W = tf.Variable(tf.constant([[0.1, 0.1], [0.1, 0.1]]))
        
    def call(self, x):
        # Some transformations with your layers
        x = tf.where(x==0, x, tf.tanh(self.W*x)/self.W)
        return x

所以,对于非零矩阵 MyModel()(tf.constant([[1.0, 2.0], [3.0, 4.0]])) 它返回

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0.9966799, 1.9737529],
       [2.913126 , 3.79949  ]], dtype=float32)>

对于零矩阵 MyModel()(tf.constant([[0.0, 0.0], [0.0, 0.0]])) 它返回零

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[0., 0.],
       [0., 0.]], dtype=float32)>