TensorFlow: computing gradients w.r.t. a sub-tensor

Asked: 2017-06-05 20:55:20

Tags: python tensorflow

Let v be a Tensor. If I compute the gradient of another Tensor w.r.t. the whole of v, everything works fine, i.e.

grads = tf.gradients(loss_func, v)

works as expected.
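
For reference, here is a minimal runnable version of the working case, with made-up shapes and a toy loss standing in for my real loss_func (TensorFlow 1.x graph mode, as in the traceback below):

import tensorflow as tf

v = tf.Variable(tf.ones([3, 4]))           # the full tensor
loss_func = tf.reduce_sum(tf.square(v))    # toy loss that depends on v
grads = tf.gradients(loss_func, v)         # list with one gradient tensor of shape (3, 4)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads))                 # prints the gradient, no error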

However, when I want to compute the gradient w.r.t. a single element of v, or w.r.t. any slice of v, I get an error. That is,

grads = tf.gradients(loss_func, v[0,0])
grads = tf.gradients(loss_func, v[:,1:])

produce the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/henning/anaconda/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 866, in runfile
    execfile(filename, namespace)
  File "/Users/henning/anaconda/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/Users/henning/pflow/testing.py", line 89, in <module>
    theta = sess.run(grads, feed_dict={P:P_inp, Q:Q_inp})
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 952, in _run
    fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 408, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 230, in for_fetch
    return _ListFetchMapper(fetch)
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 337, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 337, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/Users/henning/anaconda/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 227, in for_fetch
    (fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
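
As far as I can tell, the None comes from tf.gradients itself rather than from sess.run; a quick check, again with a toy loss in place of my real loss_func:

import tensorflow as tf

v = tf.Variable(tf.ones([3, 4]))
loss_func = tf.reduce_sum(tf.square(v))
print(tf.gradients(loss_func, v[0, 0]))    # [None]: loss_func does not depend on the slice op
print(tf.gradients(loss_func, v[:, 1:]))   # [None]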

What am I doing wrong?

1 Answer:

Answer 0 (score: 1)

I found a solution to the problem.

The most elegant approach I found is to 'build' v from constants and Variables, and then compute the gradient w.r.t. the Variable part only, i.e.

v_free = tf.Variable(tf.zeros(free_shape))        # the part of v we want gradients for
v_notfree = tf.constant(0.0, shape=fixed_shape)   # the part of v that stays fixed
v = tf.concat([v_notfree, v_free], axis=0)        # assemble the full tensor
loss_func = some_function(v)                      # loss defined on the full v
grads = tf.gradients(loss_func, v_free)           # gradient w.r.t. the free part only
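
For completeness, a fuller sketch of the same idea with made-up shapes and a toy loss (here the first two rows of v are fixed and the last row is free):

import numpy as np
import tensorflow as tf

v_notfree = tf.constant(np.zeros((2, 4), dtype=np.float32))  # fixed rows of v
v_free = tf.Variable(tf.ones([1, 4]))                        # row we want gradients for
v = tf.concat([v_notfree, v_free], axis=0)                   # full (3, 4) tensor

loss_func = tf.reduce_sum(tf.square(v))       # toy loss standing in for the real one
grads = tf.gradients(loss_func, v_free)       # well-defined: loss_func depends on v_free

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads))                    # gradient only for the free row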