Tensorflow中的二阶自定义渐变

时间:2018-09-12 18:33:19

标签: python tensorflow

我想在Tensorflow中覆盖函数梯度的梯度,并且我一直在玩y = x ** 5的玩具示例(但仍然无法使其正常工作)。

import tensorflow as tf
from tensorflow.python.framework import ops

@tf.custom_gradient
def f(x):
    y = x ** 5

    @tf.custom_gradient
    def grad(dy):
        yp = ops.get_gradient_function(y.op)(y.op, dy)[0]

        def grad2(d2y):
            ypp = ops.get_gradient_function(yp.op)(yp.op, d2y)[0]
            return ypp
        return yp, grad2
    return y, grad

tf.reset_default_graph()
x = tf.placeholder(tf.float32)
y = f(x)
with tf.Session() as sess:
    grad1 = tf.gradients(y, x)[0]
    grad2 = tf.gradients(grad1, x)
    print(sess.run(grad1, feed_dict={x: 2}))
    print(sess.run(grad2, feed_dict={x: 2}))

输出为:

80.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-176-5b0fe15eb6a1> in <module>()
     52     writer = tf.summary.FileWriter('logs', sess.graph)
     53     print(sess.run(grad1, feed_dict={x: 2}))
---> 54     print(sess.run(grad2, feed_dict={x: 2}))
     55     writer.close()
     56 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    875     try:
    876       result = self._run(None, fetches, feed_dict, options_ptr,
--> 877                          run_metadata_ptr)
    878       if run_metadata:
    879         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1083     # Create a fetch handler to take care of the structure of fetches.
   1084     fetch_handler = _FetchHandler(
-> 1085         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
   1086 
   1087     # Run request and get response.

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
    425     """
    426     with graph.as_default():
--> 427       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    428     self._fetches = []
    429     self._targets = []

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    243     elif isinstance(fetch, (list, tuple)):
    244       # NOTE(touts): This is also the code path for namedtuples.
--> 245       return _ListFetchMapper(fetch)
    246     elif isinstance(fetch, dict):
    247       return _DictFetchMapper(fetch)

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in __init__(self, fetches)
    350     """
    351     self._fetch_type = type(fetches)
--> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    354 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in <listcomp>(.0)
    350     """
    351     self._fetch_type = type(fetches)
--> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
    354 

/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in for_fetch(fetch)
    240     if fetch is None:
    241       raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
--> 242                                                                  type(fetch)))
    243     elif isinstance(fetch, (list, tuple)):
    244       # NOTE(touts): This is also the code path for namedtuples.

TypeError: Fetch argument None has invalid type <class 'NoneType'>

一阶梯度是正确的:5 *(x ** 4)= 5 *(2 ** 4)= 80 但是二阶梯度似乎返回None,因为它不是x的函数?有谁知道覆盖二阶梯度的正确方法?谢谢!

0 个答案:

没有答案