Why does tf.gradients() return TypeError: Fetch argument None has invalid type?

Time: 2019-03-28 23:52:03

Tags: python tensorflow

I'm trying to use the tf.gradients() function to get the gradients between a set of values and a model, but it always gives me an error.

The error raised at the sess.run() line is TypeError: Fetch argument None has invalid type <class 'NoneType'>, but as far as I know I haven't specified any fetch arguments, so I'm not sure what is causing it.

Please help me understand where the error comes from. Thanks.

(I've also noticed that when I print the gradients, I get [None, None]. Is that related?)
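One plausible connection, sketched as a toy rather than TensorFlow code: `tf.gradients(ys, xs)` returns `None` for each tensor in `xs` that `ys` does not depend on, and in the code below `model_value` is a fresh `tf.placeholder` with no inputs, so it has no path back to `W_1` or `W_2`. The `Node`, `depends_on` and `toy_gradients` names are hypothetical and exist only for this illustration; the sketch models only the reachability check, not the actual derivative computation.

```python
# Toy dependency graph, NOT TensorFlow: hypothetical Node/depends_on/toy_gradients
# names, used only to show why a gradient can come back as None.
from typing import List, Optional, Sequence


class Node:
    """A graph node that records which nodes it was computed from."""

    def __init__(self, name: str, inputs: Sequence["Node"] = ()) -> None:
        self.name = name
        self.inputs = list(inputs)


def depends_on(y: Node, x: Node) -> bool:
    """Return True if x can be reached from y by following input edges."""
    stack: List[Node] = [y]
    seen = set()
    while stack:
        node = stack.pop()
        if node is x:
            return True
        if id(node) not in seen:
            seen.add(id(node))
            stack.extend(node.inputs)
    return False


def toy_gradients(y: Node, xs: Sequence[Node]) -> List[Optional[str]]:
    """Give a symbolic label for each connected x and None for each unconnected x."""
    return [f"d{y.name}/d{x.name}" if depends_on(y, x) else None for x in xs]


# Same shape as the question's graph, with toy nodes standing in for tensors.
feature_vector = Node("feature_vector_")
W_1 = Node("W_1")
W_2 = Node("W_2")
hidden_1 = Node("hidden_1", [feature_vector, W_1])  # relu(matmul(x, W_1))
value = Node("value", [hidden_1, W_2])              # tanh(matmul(hidden_1, W_2))
model_value = Node("model_value")                   # placeholder: no inputs

print(toy_gradients(model_value, [W_1, W_2]))  # [None, None]
print(toy_gradients(value, [W_1, W_2]))        # ['dvalue/dW_1', 'dvalue/dW_2']
```

If that reading is right, differentiating `model.value` (which is computed from `W_1` and `W_2`) instead of a separate placeholder would yield non-None gradients, although running them would then also require feeding `model.feature_vector_` and initializing the variables.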

import tensorflow as tf

class Model:
    def __init__(self, input_dim, hidden_dim, name="model"):

        with tf.variable_scope(name):
            self.feature_vector_ = tf.placeholder(tf.float32,
                                                  shape=[None, input_dim],
                                                  name='feature_vector_')
            with tf.variable_scope('layer_1'):
                self.W_1 = tf.get_variable('W_1',
                                           shape=[input_dim, hidden_dim],
                                           initializer=tf.contrib.layers.xavier_initializer())
                hidden_1 = tf.nn.relu(tf.matmul(self.feature_vector_, self.W_1), name='hidden_1')

            with tf.variable_scope('layer_2'):
                self.W_2 = tf.get_variable('W_2', shape=[hidden_dim, 1],
                                           initializer=tf.contrib.layers.xavier_initializer())
                self.value = tf.tanh(tf.matmul(hidden_1, self.W_2), name='value')

            self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                         scope=tf.get_variable_scope().name)


if __name__ == "__main__":
    with tf.device("/cpu:0"):

        model = Model(5,5,"test")

        model_value = tf.placeholder(tf.float32, shape=(1))
        gradients = tf.gradients(model_value, model.trainable_variables)

    with tf.Session() as sess:
        values = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]

        for i in range(0,len(values)):
            grads = sess.run(gradients, feed_dict={model_value:[values[i]]})

Here is the stack trace:

Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/tf-grads/toy_gradients_example", line 35, in <module>
    grads = sess.run(gradients, feed_dict={model_value:[values[i]]})
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 929, in run
    run_metadata_ptr)
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 1137, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 471, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 370, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 258, in for_fetch
    type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>
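`tf.Session.run` takes the fetches as its first parameter (`run(fetches, feed_dict=None, options=None, run_metadata=None)`), so in `sess.run(gradients, feed_dict=...)` the `gradients` list itself is the fetch argument. The frames above show `_ListFetchMapper` calling `_FetchMapper.for_fetch` on each element of that list. The minimal toy below imitates only that visible behavior; `check_fetches` is a hypothetical name and this is not TensorFlow's implementation.

```python
# Hypothetical toy, NOT TensorFlow's implementation: it mimics only what the
# traceback shows, i.e. a fetch list is walked element by element and a None
# element is rejected with a TypeError.
from typing import Any, List


def check_fetches(fetches: Any) -> List[Any]:
    """Flatten a (possibly nested) list/tuple of fetches, rejecting None."""
    if isinstance(fetches, (list, tuple)):
        flat: List[Any] = []
        for fetch in fetches:
            flat.extend(check_fetches(fetch))
        return flat
    if fetches is None:
        # Produces the same text as the error reported in the question.
        raise TypeError("Fetch argument %r has invalid type %r"
                        % (fetches, type(fetches)))
    return [fetches]


gradients = [None, None]  # what the question reports printing for tf.gradients(...)
try:
    check_fetches(gradients)
except TypeError as err:
    print(err)  # Fetch argument None has invalid type <class 'NoneType'>
```

Under that reading, the `[None, None]` printed for the gradients and the `TypeError` are the same problem: each `None` element of the list becomes an invalid fetch.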

0 answers