I need a custom activation function like the one shown below:
Here is how I implemented it with TensorFlow:
import tensorflow as tf

sess = tf.Session()

def s_lamda_activation(f, lam):
    positive = tf.nn.relu(f - lam)
    positive = positive * (f / positive)
    positive = tf.where(tf.is_nan(positive), tf.zeros_like(positive), positive)
    negative = tf.nn.relu((-f) - lam)
    negative = negative * (f / negative)
    negative = tf.where(tf.is_nan(negative), tf.zeros_like(negative), negative)
    return positive + negative

a = tf.constant([[1,2,3,4,5,10,-10,14,-20],[-100,-2,-3,-4,-5,-10,10,-14,-20]], dtype=tf.float32)
a = s_lamda_activation(a, 5)
print(sess.run(a))
Output:
[[ 0. 0. 0. 0. 0. 10. -10. 14. -20.]
[-100. 0. 0. 0. 0. -10. 10. -14. -20.]]
However, tf.where may cause gradient problems, and with this implementation the loss does not decrease.
I removed tf.where and changed the code to:
import tensorflow as tf

sess = tf.Session()

def s_lamda_activation(f, lam):
    positive = tf.nn.relu(f - lam)
    negative = tf.nn.relu((-f) - lam)
    return positive - negative

a = tf.constant([[1,2,3,4,5,10,-10,14,-20],[-100,-2,-3,-4,-5,-10,10,-14,-20]], dtype=tf.float32)
a = s_lamda_activation(a, 5)
print(sess.run(a))
Output:
[[ 0. 0. 0. 0. 0. 5. -5. 9. -15.]
[-95. 0. 0. 0. 0. -5. 5. -9. -15.]]
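For reference, the two versions compute different functions. Reading them off the code and the printed outputs (the names h and s are my own notation, and the closed forms are inferred rather than quoted from the original formula):

```latex
h_\lambda(f) =
\begin{cases}
  f, & |f| > \lambda \\
  0, & |f| \le \lambda
\end{cases}
\qquad\qquad
s_\lambda(f) = \operatorname{sign}(f)\,\max\bigl(|f| - \lambda,\ 0\bigr)
```

The first version matches h (often called hard thresholding) and the second matches s (soft thresholding). The soft version shrinks every surviving value by lambda, which is why 10 becomes 5 and -20 becomes -15 with lam = 5.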
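This implementation works fine, and the loss decreases as expected. However, it is not the same as the original activation function defined above. Any suggestions on how to implement the function correctly and efficiently? Does tf.where cause gradient problems?
Thanks a lot for your help!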
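Answer 0 (score: 0)
The problem is that you are not using tf.where() correctly to implement the activation function. You can use tf.gradients to inspect the gradient, as follows: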
import tensorflow as tf
...
result = s_lamda_activation(a, 5)
grad = tf.gradients(result, a)
with tf.Session() as sess:
    print(sess.run(result))
    print(sess.run(grad))
[[ 0. 0. 0. 0. 0. 10. -10. 14. -20.]
[-100. 0. 0. 0. 0. -10. 10. -14. -20.]]
[array([[nan, nan, nan, nan, nan, nan, nan, nan, nan],
[nan, nan, nan, nan, nan, nan, nan, nan, nan]], dtype=float32)]
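The NaN comes from the division rather than from tf.where picking a wrong value. Wherever relu(...) is 0, f/positive is inf; tf.where hides that in the forward pass, but in the backward pass the zero gradient it sends to the masked branch is multiplied by inf, and 0 * inf is nan. For lam >= 0 at least one of the two branches is zero for every input, so every entry of the summed gradient becomes nan. Below is a plain-Python sketch of the chain rule for one branch. `ieee_div` and `branch_grad` are hypothetical helpers written for this illustration, and the derivation is a simplification of what TensorFlow does internally.

```python
import math

def ieee_div(a, b):
    """Float division that yields inf/nan on a zero divisor, as tensor kernels do,
    instead of raising ZeroDivisionError like plain Python."""
    if b != 0.0:
        return a / b
    if a == 0.0 or math.isnan(a):
        return math.nan
    return math.copysign(math.inf, a)  # sign of b ignored: b is +0.0 in this sketch

def branch_grad(f, lam, upstream):
    """Chain rule for d/df of p * (f / p), where p = relu(f - lam).

    `upstream` is the gradient arriving from tf.where: 1.0 where the branch
    value was kept, 0.0 where it was replaced by zero.
    """
    p = max(f - lam, 0.0)
    dp_df = 1.0 if p > 0.0 else 0.0
    q = ieee_div(f, p)               # forward value of f / p
    dq_df = ieee_div(1.0, p)         # partial of f / p with respect to f
    dq_dp = -ieee_div(f, p * p)      # partial of f / p with respect to p
    grad_q = upstream * p            # gradient reaching q through the product
    grad_p = upstream * q + grad_q * dq_dp  # product path plus division path
    return grad_q * dq_df + grad_p * dp_df

print(branch_grad(10.0, 5.0, 1.0))  # active branch: finite gradient
print(branch_grad(3.0, 5.0, 0.0))   # masked branch: 0 * inf poisons the result
```

The masked call returns nan even though its upstream gradient is exactly 0, which is the same effect tf.gradients reports above.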
The correct usage is simple:
import tensorflow as tf

def s_lamda_activation(f, lam):
    return tf.where(tf.greater(tf.abs(f), lam), f, tf.zeros_like(f))

a = tf.constant([[1,2,3,4,5,10,-10,14,-20],[-100,-2,-3,-4,-5,-10,10,-14,-20]], dtype=tf.float32)
result = s_lamda_activation(a, 5)
grad = tf.gradients(result, a)
with tf.Session() as sess:
    print(sess.run(result))
    print(sess.run(grad))
[[ 0. 0. 0. 0. 0. 10. -10. 14. -20.]
[-100. 0. 0. 0. 0. -10. 10. -14. -20.]]
[array([[0., 0., 0., 0., 0., 1., 1., 1., 1.],
[1., 0., 0., 0., 0., 1., 1., 1., 1.]], dtype=float32)]
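The gradient is 0 or 1 because tf.where passes the incoming gradient only to the selected input, and neither input here involves a division, so nothing can turn into NaN. As a framework-free cross-check of the numbers above, here is a small plain-Python sketch. The helper names are hypothetical and not part of TensorFlow.

```python
def hard_threshold(f, lam):
    """Keep f where |f| > lam, otherwise return 0 (mirrors the tf.where version)."""
    return f if abs(f) > lam else 0.0

def hard_threshold_grad(f, lam):
    """Derivative with respect to f: 1 where f is passed through, 0 where it is zeroed."""
    return 1.0 if abs(f) > lam else 0.0

rows = [[1, 2, 3, 4, 5, 10, -10, 14, -20],
        [-100, -2, -3, -4, -5, -10, 10, -14, -20]]
lam = 5

values = [[hard_threshold(float(f), lam) for f in row] for row in rows]
grads = [[hard_threshold_grad(float(f), lam) for f in row] for row in rows]

print(values)
print(grads)
```

Note that the derivative is 0 wherever |f| <= lam, so inputs inside that band receive no gradient through this activation.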