应用@ tf.custom_gradient

时间:2019-03-26 13:06:48

标签: python tensorflow

我正在实施 Lei等。 (2015)的《合理化神经预测》 神经网络以生成针对 NLP 问题的提取摘要。

该体系结构需要一个采样层,该采样层从文本输入中选择可用单词的子集。由于采样是不可区分的,因此我使用 @ tf.custom_gradient ,将0的'artificial'梯度分配给采样层。

问题在于,采样层上方的各层不再训练(看起来,网络下方的所有梯度流实际上已被消除)。

我的问题是如何避免这种影响并使网络忽略或跳过不可微分层?

我尝试了硬的S形激活(例如alpha = 60)以近似选择0到1之间的值,但这还不够好(更不用说优雅了)。

这是具有自定义渐变的图层

@tf.custom_gradient
def sampling(x):
    def grad(dy):
        return x*0+0
    dist = tfd.Bernoulli(probs=x,dtype = tf.float32)  
    return dist.sample(),grad

所以,这就是发生的情况:

  1. 神经网络第1层#不训练
  2. 神经网络第2层#不训练
  3. 具有自定义渐变的图层#不应该训练
  4. 下一层#训练良好
  5. 下一层#训练良好
  6. 下一层#训练良好
  7. 输出

那么我如何使1)和2)仍在训练?

0 个答案:

没有答案