Multiple networks produce TensorFlow TypeError: Fetch argument None has invalid type <class 'NoneType'>

Time: 2018-03-10 19:57:44

Tags: python tensorflow neural-network reinforcement-learning

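I'm working with OpenAI Gym to train an actor-critic network, where one network provides the actions and the second network provides the expected value. However, when I try to get the gradients from the networks so that I can apply the updates later, I keep getting TypeError: Fetch argument None has invalid type <class 'NoneType'>. It only occurs when I run with the critic network, or if I run a second actor network. I defined them with different tf.variable_scope values and passed in the same session, so it seems to me that it should work, and I can't figure out why it doesn't. I found other posts here, here, and here, but they didn't solve my problem.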

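My networks are given below. For brevity I've removed the layers and the other working methods; at this level of abstraction the actor network is almost identical, just with a different loss function. I can provide more code if necessary: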

import tensorflow as tf
from tensorflow.contrib.layers import fully_connected

# Define critic network
class critic(object):
    def __init__(self, sess, scope):

        self.sess = sess
        self.scope = scope
        # self.n_inputs and self.n_out are defined in code omitted for brevity
        with tf.variable_scope(self.scope):
            # Network inputs, outputs, rewards, optimizer, etc...
            self.state = tf.placeholder(tf.float32, [None, self.n_inputs],
                                        name='state')
            self.returns = tf.placeholder(tf.float32, [None], name='returns')
            # Single, linear layer
            self.output = fully_connected(self.state, self.n_out,
                                          activation_fn=None,
                                          weights_initializer=None)

            self.est_state_value = tf.squeeze(self.output)
            # Define loss function
            self.loss = tf.squared_difference(self.est_state_value, self.returns)
            self.trainable_variables = tf.trainable_variables()
            self.gradients = tf.gradients(self.loss, self.trainable_variables)

    # Methods for prediction, updating, etc...

The get_grads method, which returns the network's gradients, is what causes the problem:

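# This variant feeds actions (the actor's); per the traceback below, the
# critic's get_grads takes (states, returns) and feeds only state and returns
def get_grads(self, states, actions, returns):
    grads = self.sess.run([self.gradients],
                          feed_dict={
                              self.state: states,
                              self.actions: actions,
                              self.returns: returns,
                          })[0]
    return grads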

When running the algorithm, the error is thrown on the second get_grads call:

import numpy as np
import tensorflow as tf

tf.reset_default_graph()

sess = tf.Session()
act = actor(sess, scope='actor')
crit = critic(sess, scope='critic')
init = tf.global_variables_initializer()
act.sess.run(init)
crit.sess.run(init)
# Randomized data for example
rewards = np.ones(10)
actions = np.random.choice([0, 1], 10)
states = np.random.normal(size=(10, 4))

act.get_grads(states, actions, rewards)
crit.get_grads(states, rewards)

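This made me think it might be caused by the similar naming conventions of the two networks, so I tried changing those, using two separate tf.Session() objects, and a few other things, but the problem persists. If I run only one network, either the actor or the critic, everything works fine and it learns well. So I'm not sure what is causing this error or how to fix it. I'd appreciate any help.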

Full traceback:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-78-c56d39a21e63> in <module>()
         13 
         14 act.get_grads(states, actions, rewards)
    ---> 15 crit.get_grads(states, rewards)

    <ipython-input-76-031f8b9688f5> in get_grads(self, states, returns)
         53             feed_dict={
         54             self.state: states,
    ---> 55             self.returns: returns
         56             })
         57         return grads

    ...\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
        903     try:
        904       result = self._run(None, fetches, feed_dict, options_ptr,
    --> 905                          run_metadata_ptr)
        906       if run_metadata:
        907         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

    ...\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
       1120     # Create a fetch handler to take care of the structure of fetches.
       1121     fetch_handler = _FetchHandler(
    -> 1122         self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
       1123 
       1124     # Run request and get response.

    ...\client\session.py in __init__(self, graph, fetches, feeds, feed_handles)
        425     """
        426     with graph.as_default():
    --> 427       self._fetch_mapper = _FetchMapper.for_fetch(fetches)
        428     self._fetches = []
        429     self._targets = []

    ...\tensorflow\python\client\session.py in for_fetch(fetch)
        243     elif isinstance(fetch, (list, tuple)):
        244       # NOTE(touts): This is also the code path for namedtuples.
    --> 245       return _ListFetchMapper(fetch)
        246     elif isinstance(fetch, dict):
        247       return _DictFetchMapper(fetch)

    ...\tensorflow\python\client\session.py in __init__(self, fetches)
        350     """
        351     self._fetch_type = type(fetches)
    --> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        354 

    ...\tensorflow\python\client\session.py in <listcomp>(.0)
        350     """
        351     self._fetch_type = type(fetches)
    --> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        354 

    ...\tensorflow\python\client\session.py in for_fetch(fetch)
        243     elif isinstance(fetch, (list, tuple)):
        244       # NOTE(touts): This is also the code path for namedtuples.
    --> 245       return _ListFetchMapper(fetch)
        246     elif isinstance(fetch, dict):
        247       return _DictFetchMapper(fetch)

    ...\python\client\session.py in __init__(self, fetches)
        350     """
        351     self._fetch_type = type(fetches)
    --> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        354 

    ...\client\session.py in <listcomp>(.0)
        350     """
        351     self._fetch_type = type(fetches)
    --> 352     self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
        353     self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
        354 

    ...\client\session.py in for_fetch(fetch)
        240     if fetch is None:
        241       raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
    --> 242                                                                  type(fetch)))
        243     elif isinstance(fetch, (list, tuple)):
        244       # NOTE(touts): This is also the code path for namedtuples.

    TypeError: Fetch argument None has invalid type <class 'NoneType'>
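
The traceback shows two nested _ListFetchMapper frames because the fetch passed to sess.run is [self.gradients], a list wrapping the gradient list. The sketch below is a simplified, pure-Python model of that dispatch, based only on the frames visible above: lists and tuples are walked recursively, and a None anywhere in the structure is rejected with the same message. It is not TensorFlow's implementation; flatten_fetches is a hypothetical name, and strings stand in for gradient tensors.

```python
def flatten_fetches(fetch):
    """Simplified model of the for_fetch dispatch seen in the traceback.

    Lists/tuples are recursed into, None raises the same TypeError text,
    and anything else is treated as a valid leaf fetch.
    """
    if fetch is None:
        raise TypeError('Fetch argument %r has invalid type %r'
                        % (fetch, type(fetch)))
    if isinstance(fetch, (list, tuple)):
        leaves = []
        for item in fetch:
            leaves.extend(flatten_fetches(item))
        return leaves
    return [fetch]


# Same nesting as [self.gradients]: an outer list wrapping the gradient list.
print(flatten_fetches([['grad_w', 'grad_b']]))  # ['grad_w', 'grad_b']

try:
    flatten_fetches([[None, 'grad_w']])
except TypeError as e:
    print(e)  # Fetch argument None has invalid type <class 'NoneType'>
```

A single None entry inside self.gradients is therefore enough to make the whole sess.run call fail.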

1 Answer:

Answer 0 (score: 0)

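Although I called self.trainable_variables = tf.trainable_variables() inside a unique tf.variable_scope(self.scope), tf.trainable_variables() is not restricted to the enclosing variable scope: it returns every trainable variable in the graph. Because I build the networks one after the other, the first network (the actor) was set up correctly, since only its own variables existed at that point, but the second network (the critic) assigned all trainable variables created so far, including the actor's, to self.trainable_variables. The critic's loss does not depend on the actor's variables, so tf.gradients returns None for them, and sess.run raises the TypeError when asked to fetch those None entries. To fix it, I just needed to be explicit about the scope when defining each network's variables, by changing the call to:

self.trainable_variables = tf.trainable_variables(self.scope)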
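
The ordering problem can be reproduced without TensorFlow. Below is a small pure-Python model, assuming (as TF 1.x's get_collection does) that the scope argument is applied to each variable name with re.match. GraphSketch, gradients_sketch, build_network and the variable names are hypothetical stand-ins, not TensorFlow objects; gradients_sketch only mimics the fact that tf.gradients yields None for variables the loss does not depend on.

```python
import re


class GraphSketch:
    """Toy stand-in for a graph's TRAINABLE_VARIABLES collection."""

    def __init__(self):
        self.trainable = []  # variable names, in creation order

    def add_variable(self, name):
        self.trainable.append(name)

    def trainable_variables(self, scope=None):
        # No scope: every trainable variable created so far is returned.
        if scope is None:
            return list(self.trainable)
        # With a scope: keep names that match it from the start (re.match).
        return [v for v in self.trainable if re.match(scope, v)]


def gradients_sketch(loss_deps, variables):
    """Mimic tf.gradients: None for variables the loss does not depend on."""
    return ['d_loss/' + v if v in loss_deps else None for v in variables]


def build_network(graph, scope, use_scope_filter):
    # Each toy network owns one weight and one bias under its own scope.
    own = [scope + '/fully_connected/weights:0',
           scope + '/fully_connected/biases:0']
    for name in own:
        graph.add_variable(name)
    tvars = graph.trainable_variables(scope if use_scope_filter else None)
    return gradients_sketch(set(own), tvars)


# Original code: unfiltered lookup, actor built before critic.
g = GraphSketch()
actor_grads = build_network(g, 'actor', use_scope_filter=False)
critic_grads = build_network(g, 'critic', use_scope_filter=False)
print(None in actor_grads)   # False: only the actor's variables existed yet
print(None in critic_grads)  # True: the actor's variables were collected too

# Fixed code: lookup filtered by scope.
g = GraphSketch()
build_network(g, 'actor', use_scope_filter=True)
critic_grads = build_network(g, 'critic', use_scope_filter=True)
print(None in critic_grads)  # False
```

Because the filter is a regular-expression match from the start of the name, a scope such as 'actor' would also pick up variables under a scope like 'actor_target'; passing 'actor/' avoids that.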