I'm working with OpenAI Gym to train an actor-critic network, where one network provides the action and a second network provides the expected value. However, when I try to fetch the gradients from the networks so I can update them later, I keep getting TypeError: Fetch argument None has invalid type <class 'NoneType'>
The error only appears when I run both the critic network and the actor network together. I defined them with different tf.variable_scope
values and passed both the same session, so it seems to me it should work, and I can't figure out why it doesn't. I found other posts here, here, and here, but they didn't solve my problem.
My networks are given below (for brevity I removed layers and other working methods; the actor network is nearly identical at this level of abstraction, just with a different loss function; I can provide more code if necessary):
# Define critic network
class critic(object):
    def __init__(self, sess, scope):
        self.sess = sess
        self.scope = scope
        with tf.variable_scope(self.scope):
            # Network inputs, outputs, rewards, optimizer, etc...
            self.state = tf.placeholder(tf.float32, [None, self.n_inputs],
                                        name='state')
            self.returns = tf.placeholder(tf.float32, [None], name='returns')
            # Single, linear layer
            self.output = fully_connected(self.state, self.n_out,
                                          activation_fn=None,
                                          weights_initializer=None)
            self.est_state_value = tf.squeeze(self.output)
            # Define loss function
            self.loss = tf.squared_difference(self.est_state_value, self.returns)
            self.trainable_variables = tf.trainable_variables()
            self.gradients = tf.gradients(self.loss, self.trainable_variables)

    # Methods for prediction, updating, etc...
The get_grads
method, which returns the network's gradients, is what causes the problem:
def get_grads(self, states, actions, returns):
    grads = self.sess.run([self.gradients],
                          feed_dict={
                              self.state: states,
                              self.actions: actions,
                              self.returns: returns
                          })[0]
    return grads
When running the algorithm, the error is thrown on the second get_grads
call:
tf.reset_default_graph()
sess = tf.Session()
act = actor(sess, scope='actor')
crit = critic(sess, scope='critic')
init = tf.global_variables_initializer()
act.sess.run(init)
crit.sess.run(init)
# Randomized data for example
rewards = np.ones(10)
actions = np.random.choice([0, 1], 10)
states = np.random.normal(size=(10, 4))
act.get_grads(states, actions, rewards)
crit.get_grads(states, rewards)
This made me think it might be caused by the similar naming conventions between the two networks, so I tried changing those, using two separate tf.Session()
instances, and other things, but the problem persists. If I run only one network - either the actor or the critic - everything works fine and it learns well. So, I'm not sure what is going on here that causes this error, or how to fix it. I'd appreciate your help.
Full traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-78-c56d39a21e63> in <module>()
13
14 act.get_grads(states, actions, rewards)
---> 15 crit.get_grads(states, rewards)
<ipython-input-76-031f8b9688f5> in get_grads(self, states, returns)
53 feed_dict={
54 self.state: states,
---> 55 self.returns: returns
56 })
57 return grads
...\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
903 try:
904 result = self._run(None, fetches, feed_dict, options_ptr,
--> 905 run_metadata_ptr)
906 if run_metadata:
907 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
...\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1120 # Create a fetch handler to take care of the structure of fetches.
1121 fetch_handler = _FetchHandler(
-> 1122 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1123
1124 # Run request and get response.
...\client\session.py in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
...\tensorflow\python\client\session.py in for_fetch(fetch)
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
--> 245 return _ListFetchMapper(fetch)
246 elif isinstance(fetch, dict):
247 return _DictFetchMapper(fetch)
...\tensorflow\python\client\session.py in __init__(self, fetches)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
...\tensorflow\python\client\session.py in <listcomp>(.0)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
...\tensorflow\python\client\session.py in for_fetch(fetch)
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
--> 245 return _ListFetchMapper(fetch)
246 elif isinstance(fetch, dict):
247 return _DictFetchMapper(fetch)
...\python\client\session.py in __init__(self, fetches)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
...\client\session.py in <listcomp>(.0)
350 """
351 self._fetch_type = type(fetches)
--> 352 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
353 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
354
...\client\session.py in for_fetch(fetch)
240 if fetch is None:
241 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
--> 242 type(fetch)))
243 elif isinstance(fetch, (list, tuple)):
244 # NOTE(touts): This is also the code path for namedtuples.
TypeError: Fetch argument None has invalid type <class 'NoneType'>
Answer 0: (score: 0)
Although I called self.trainable_variables = tf.trainable_variables()
inside a unique tf.variable_scope(self.scope)
, the order in which I initialized the networks meant the first network collected only its own variables, while the second network's call to tf.trainable_variables()
collected the trainable variables of both networks, since both live in the same default graph. tf.gradients then returns None for the other network's variables, because that loss doesn't depend on them, and sess.run cannot fetch None. To fix it, I just needed to restrict the collection to the network's own scope by changing the call to:
self.trainable_variables = tf.trainable_variables(self.scope)
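To illustrate the mechanism without needing TensorFlow installed, here is a minimal pure-Python sketch. The FakeVariable class, variable names, and the two helper functions are hypothetical stand-ins that mimic the relevant behavior: tf.trainable_variables(scope) filters variables by name prefix, and tf.gradients yields None for any variable the loss does not depend on.

```python
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable; only the name matters here."""
    def __init__(self, name):
        self.name = name

# Variables as they would exist in the shared default graph after both
# networks are built (names are illustrative).
graph_variables = [
    FakeVariable("actor/fully_connected/weights:0"),
    FakeVariable("actor/fully_connected/biases:0"),
    FakeVariable("critic/fully_connected/weights:0"),
    FakeVariable("critic/fully_connected/biases:0"),
]

def trainable_variables(scope=None):
    """Mimics tf.trainable_variables(scope): filter by name prefix."""
    if scope is None:
        return list(graph_variables)
    return [v for v in graph_variables if v.name.startswith(scope)]

def gradients(loss_scope, variables):
    """Mimics tf.gradients: None for variables the loss doesn't touch."""
    return ["grad(%s)" % v.name if v.name.startswith(loss_scope) else None
            for v in variables]

# Unscoped collection: the critic's gradient list contains None entries
# for the actor's variables, so sess.run would raise the TypeError.
grads_bad = gradients("critic", trainable_variables())
print(any(g is None for g in grads_bad))     # True

# Scoped collection: only the critic's own variables, no None entries.
grads_ok = gradients("critic", trainable_variables("critic"))
print(all(g is not None for g in grads_ok))  # True
```

The same reasoning explains why each network alone trains fine: with a single network in the graph, the unscoped tf.trainable_variables() happens to return exactly the variables its loss depends on.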