TensorFlow: AttributeError: 'dict' object has no attribute 'SerializeToString'

Date: 2017-07-20 19:23:47

Tags: python tensorflow

In TensorFlow, I built a neural network as follows:

x = tf.placeholder(tf.float32)
x_ = tf.placeholder(tf.float32)
th = tf.placeholder(tf.float32)
th_ = tf.placeholder(tf.float32)


rlu_1 = tf.contrib.layers.fully_connected(inputs=tf.reshape([x,x_,th,th_],[1,4]),num_outputs=10)

# 4 state features: x, x_, th, th_
rlu_1.weights_initializer = tf.random_uniform(shape=[4],minval=-1,maxval=1) # is this 4 or 10?
rlu_1.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
rlu_2 = tf.contrib.layers.fully_connected(inputs=rlu_1,num_outputs=10) # hope that makes a copy
rlu_2.weights_initializer = tf.random_uniform(shape=[10],minval=-1,maxval=1)
rlu_2.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
Qvals = tf.contrib.layers.fully_connected(inputs=rlu_2,num_outputs=2)
Qvals.weights_initializer = tf.random_uniform(shape=[10],minval=-1,maxval=1)
Qvals.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
Qvals.activation_fn = tf.identity


xt = tf.placeholder(tf.float32)
x_t = tf.placeholder(tf.float32)
tht = tf.placeholder(tf.float32)
th_t = tf.placeholder(tf.float32)

# I build a separate copy of the network here, using [xt,x_t,tht and th_t] as inputs    
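The inline comment "is this 4 or 10?" can be answered by looking at what one fully connected layer computes. In TF 1.x, tf.contrib.layers.fully_connected creates a weight matrix of shape [num_inputs, num_outputs] and a bias vector of shape [num_outputs], and applies ReLU by default. The initializers and the activation are keyword arguments of the call itself (weights_initializer=..., biases_initializer=..., activation_fn=...), for example weights_initializer=tf.random_uniform_initializer(-1, 1) or activation_fn=None for a linear output. Assigning attributes on the returned tensor after the layer is built does not change how its variables are initialized. The NumPy sketch below only illustrates the shapes involved, assuming uniform(-1, 1) initialization and a linear output layer as the question intends; the helpers dense and relu are hypothetical.

```python
import numpy as np

# NumPy illustration of one fully connected layer:
# activation(inputs @ weights + biases). Not TensorFlow code.
rng = np.random.default_rng(0)


def relu(z):
    return np.maximum(z, 0.0)


def dense(inputs, num_outputs, activation=relu):
    """Hypothetical helper: weights and biases drawn from U(-1, 1)."""
    num_inputs = inputs.shape[-1]
    weights = rng.uniform(-1.0, 1.0, size=(num_inputs, num_outputs))  # [in, out]
    biases = rng.uniform(-1.0, 1.0, size=(num_outputs,))              # [out]
    return activation(inputs @ weights + biases), weights, biases


state = np.array([[0.01, -0.02, 0.03, 0.04]])      # [1, 4]: x, x_, th, th_
h1, w1, b1 = dense(state, 10)                      # w1: [4, 10]
h2, w2, b2 = dense(h1, 10)                         # w2: [10, 10]
q, w3, b3 = dense(h2, 2, activation=lambda z: z)   # w3: [10, 2], linear output

print(w1.shape, w2.shape, w3.shape)  # (4, 10) (10, 10) (10, 2)
print(q.shape)                       # (1, 2)
```

So the first layer's weights are neither 4 nor 10 values but a 4 x 10 matrix, and each bias vector has one entry per output unit.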

I run the session with the following code:

observation = env.reset()
observation = [float(i) for i in observation]
prev_observation = observation
#print(observation)
reward = 1.0
tfreward = tf.constant(reward, dtype=tf.float32)

train_,nextAction = sess.run(train,tf.argmax(Qvals,0),
                              {x:prev_observation[0],x_:prev_observation[1],
                              th:prev_observation[2],th_:prev_observation[3],   
                              xt:observation[0],x_t:observation[1],
                              tht:observation[2],th_t:observation[3]})

The problem is that the last line of the code above raises the following error:

  File "C:/Users/linna_t3vz49n/Documents/CS 486/a4/cartPole.py", line 236, in <module>
    tht:observation[2],th_t:observation[3]})

  File "C:\Users\linna_t3vz49n\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 783, in run
    compat.as_bytes(options.SerializeToString()))

AttributeError: 'dict' object has no attribute 'SerializeToString'
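The last frame of the traceback (session.py calling options.SerializeToString()) points at the cause. The sketch below uses a hypothetical stand-in, fake_session_run, that only mirrors how Python binds positional arguments to TF 1.x's Session.run(fetches, feed_dict=None, options=None, run_metadata=None); it is not TensorFlow code, and the string fetch names are placeholders.

```python
# Hypothetical stand-in mirroring the parameter list of TF 1.x
# Session.run(fetches, feed_dict=None, options=None, run_metadata=None).
# It is NOT TensorFlow; it only shows how the arguments get bound.


def fake_session_run(fetches, feed_dict=None, options=None, run_metadata=None):
    """Report which parameter received what, then use `options` the way the
    real method does (it serializes a RunOptions protocol buffer)."""
    bound = {"fetches": fetches, "feed_dict": feed_dict, "options": options}
    if options is not None:
        options.SerializeToString()  # a plain dict has no such method
    return bound


feeds = {"x": 0.1, "x_": 0.2}

# Original call shape: sess.run(train, tf.argmax(Qvals, 0), {...})
try:
    fake_session_run("train", "argmax_op", feeds)
except AttributeError as err:
    print(err)  # 'dict' object has no attribute 'SerializeToString'

# Fixed call shape: one list of fetches, then the feed dict
bound = fake_session_run(["train", "argmax_op"], feeds)
print(bound["feed_dict"] is feeds)  # True
```

The second fetch silently becomes feed_dict and the feed dictionary becomes options, which is exactly the object the real method tries to serialize.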

How can I fix this?

1 Answer:

Answer 0 (score: 1)

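Pass train and tf.argmax(Qvals,0) to sess.run together as a single list. The signature is Session.run(fetches, feed_dict=None, options=None, run_metadata=None), so in your call tf.argmax(Qvals,0) was bound to feed_dict and your dictionary to options. options is expected to be a RunOptions protocol buffer, which is why TensorFlow tries to call SerializeToString() on your dict: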

train_, nextAction = sess.run([train, tf.argmax(Qvals, 0)],
                              feed_dict={x: prev_observation[0], x_: prev_observation[1],
                                         th: prev_observation[2], th_: prev_observation[3],
                                         xt: observation[0], x_t: observation[1],
                                         tht: observation[2], th_t: observation[3]})
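Separately from the SerializeToString error, the axis passed to tf.argmax is worth checking. Assuming Qvals keeps the [1, 2] shape that follows from the [1, 4] reshape in the question (one row for the single state, one column per action), argmax over axis 0 compares entries within each column and therefore always returns 0s. tf.argmax uses the same axis convention as np.argmax, so the NumPy sketch below (with made-up Q-values) shows the difference; the action index is probably tf.argmax(Qvals, 1).

```python
import numpy as np

# Made-up Q-values with the [1, 2] shape Qvals has for a single state:
# one row (batch of one), one column per action.
q_values = np.array([[0.3, 0.7]])

print(np.argmax(q_values, axis=0))  # [0 0]: row index per column, always 0
print(np.argmax(q_values, axis=1))  # [1]: best action for the single state
```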