使用Tensorflow轻松检查自定义渐变

时间:2018-05-15 20:10:39

标签: python tensorflow deep-learning

我有一个非常简单的问题:我已经在Tensorflow中实现了一个复杂的自定义op及其渐变,假设前进是正确的,我想知道是否有一种简单的方法可以检查有限差分是否与您的自定义渐变相近不必以丑陋的方式重新实现它。我看到了函数B in the official doc,但源代码密集且难以阅读,我似乎无法找到任何其他相关问题或示例。
但是我确信我错过了一个超级简单的自成一体的例子? 修改 例如,如果我尝试:

tf.test.compute_gradient_error()

它抛出: AttributeError:'NoneType'对象没有属性'run' 看看gradient_checker.py我做错了什么?

1 个答案:

答案 0 :(得分:1)

所以我的问题是gradient_checker.py调用get_default_session()来获取它使用的会话,如果op没有显式连接到正在使用的会话,这显然不起作用,这通过确定范围来完成OP:

with sess.as_default_session():
  check=tf.test.compute_gradient_error()
  print check

还需要说明它需要这种方式的原因来自于检查直接是张量sess.run()的结果而不是像大多数张量流函数那样的图中的节点。