我想在Tensorflow中覆盖函数梯度的梯度,并且我一直在玩y = x ** 5
的玩具示例(但仍然无法使其正常工作)。
import tensorflow as tf
from tensorflow.python.framework import ops
@tf.custom_gradient
def f(x):
y = x ** 5
@tf.custom_gradient
def grad(dy):
yp = ops.get_gradient_function(y.op)(y.op, dy)[0]
def grad2(d2y):
ypp = ops.get_gradient_function(yp.op)(yp.op, d2y)[0]
return ypp
return yp, grad2
return y, grad
tf.reset_default_graph()
x = tf.placeholder(tf.float32)
y = f(x)
with tf.Session() as sess:
grad1 = tf.gradients(y, x)[0]
grad2 = tf.gradients(grad1, x)
print(sess.run(grad1, feed_dict={x: 2}))
print(sess.run(grad2, feed_dict={x: 2}))
输出为:
80.0
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-176-5b0fe15eb6a1> in <module>()
52 writer = tf.summary.FileWriter('logs', sess.graph)
53 print(sess.run(grad1, feed_dict={x: 2}))
---> 54 print(sess.run(grad2, feed_dict={x: 2}))
55 writer.close()
56
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
875 try:
876 result = self._run(None, fetches, feed_dict, options_ptr,
--> 877 run_metadata_ptr)
878 if run_metadata:
879 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1083 # Create a fetch handler to take care of the structure of fetches.
1084 fetch_handler = _FetchHandler(
-> 1085 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1086
1087 # Run request and get response.
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in for_fetch(fetch)
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
--> 245 return _ListFetchMapper(fetch)
246 elif isinstance(fetch, dict):
247 return _DictFetchMapper(fetch)
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in __init__(self, fetches)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in <listcomp>(.0)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py in for_fetch(fetch)
240 if fetch is None:
241 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
--> 242 type(fetch)))
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <class 'NoneType'>
一阶梯度是正确的:5 *(x ** 4)= 5 *(2 ** 4)= 80
但是二阶梯度似乎返回None
,因为它不是x的函数?有谁知道覆盖二阶梯度的正确方法?谢谢!