我正在尝试使用tf.gradients()函数返回一组值和一个模型之间的梯度,但是它总是给我带来错误。
它从sess.run()
行返回的错误是TypeError: Fetch argument None has invalid type <class 'NoneType'>
,但据我所知,我尚未指定任何提取参数,因此我不确定是什么原因引起的。
请帮助我了解错误的出处,谢谢。
(另外,我已经注意到,当我打印渐变时,它返回[None, None]
,这和它有关吗?)
import tensorflow as tf
class Model:
def __init__(self, input_dim, hidden_dim, name="model"):
with tf.variable_scope(name):
self.feature_vector_ = tf.placeholder(tf.float32,
shape=[None, input_dim],
name='feature_vector_')
with tf.variable_scope('layer_1'):
self.W_1 = tf.get_variable('W_1',
shape=[input_dim, hidden_dim],
initializer=tf.contrib.layers.xavier_initializer())
hidden_1 = tf.nn.relu(tf.matmul(self.feature_vector_, self.W_1), name='hidden_1')
with tf.variable_scope('layer_2'):
self.W_2 = tf.get_variable('W_2', shape=[hidden_dim, 1],
initializer=tf.contrib.layers.xavier_initializer())
self.value = tf.tanh(tf.matmul(hidden_1, self.W_2), name='value')
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
scope=tf.get_variable_scope().name)
if __name__ == "__main__":
with tf.device("/cpu:0"):
model = Model(5,5,"test")
model_value = tf.placeholder(tf.float32, shape=(1))
gradients = tf.gradients(model_value, model.trainable_variables)
with tf.Session() as sess:
values = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]
for i in range(0,len(values)):
grads = sess.run(gradients, feed_dict={model_value:[values[i]]})
以下是堆栈跟踪信息:
Traceback (most recent call last):
File "C:/Users/User/PycharmProjects/tf-grads/toy_gradients_example", line 35, in <module>
grads = sess.run(gradients, feed_dict={model_value:[values[i]]})
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 929, in run
run_metadata_ptr)
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 1137, in _run
self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 471, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 261, in for_fetch
return _ListFetchMapper(fetch)
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 370, in __init__
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 370, in <listcomp>
self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
File "C:\Users\User\PycharmProjects\tf-grads\venv\lib\site-packages\tensorflow\python\client\session.py", line 258, in for_fetch
type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>