In TensorFlow, I have built a neural network as follows:
x = tf.placeholder(tf.float32)
x_ = tf.placeholder(tf.float32)
th = tf.placeholder(tf.float32)
th_ = tf.placeholder(tf.float32)
rlu_1 = tf.contrib.layers.fully_connected(inputs=tf.reshape([x,x_,th,th_],[1,4]),num_outputs=10)
# 4 state features: x, x_, th, th_
rlu_1.weights_initializer = tf.random_uniform(shape=[4],minval=-1,maxval=1) # is this 4 or 10?
rlu_1.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
rlu_2 = tf.contrib.layers.fully_connected(inputs=rlu_1,num_outputs=10) # hope that makes a copy
rlu_2.weights_initializer = tf.random_uniform(shape=[10],minval=-1,maxval=1)
rlu_2.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
Qvals = tf.contrib.layers.fully_connected(inputs=rlu_2,num_outputs=2)
Qvals.weights_initializer = tf.random_uniform(shape=[10],minval=-1,maxval=1)
Qvals.biases_initializer = tf.random_uniform(shape=[1],minval=-1,maxval=1)
Qvals.activation_fn = tf.identity
xt = tf.placeholder(tf.float32)
x_t = tf.placeholder(tf.float32)
tht = tf.placeholder(tf.float32)
th_t = tf.placeholder(tf.float32)
# I build a separate copy of the network here, using [xt,x_t,tht and th_t] as inputs
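For reference, fully_connected also accepts the initializers and the activation function as keyword arguments, and it infers the weight shapes from its input (so the first layer's weight matrix is 4x10 and the second is 10x10). Below is a minimal sketch of the same three layers built that way, assuming TF 1.x with tf.contrib available and reusing the placeholders defined above; init and state are just local names introduced for the sketch:
init = tf.random_uniform_initializer(minval=-1, maxval=1)
state = tf.reshape([x, x_, th, th_], [1, 4])              # 1x4 state vector
rlu_1 = tf.contrib.layers.fully_connected(
    inputs=state, num_outputs=10,
    weights_initializer=init, biases_initializer=init)    # ReLU activation by default
rlu_2 = tf.contrib.layers.fully_connected(
    inputs=rlu_1, num_outputs=10,
    weights_initializer=init, biases_initializer=init)
Qvals = tf.contrib.layers.fully_connected(
    inputs=rlu_2, num_outputs=2, activation_fn=tf.identity,
    weights_initializer=init, biases_initializer=init)    # linear Q-value outputs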
I run the session with the following code:
observation = env.reset()
observation = [float(i) for i in observation]
prev_observation = observation
#print(observation)
reward = 1.0
tfreward = tf.constant(reward, dtype=tf.float32)
train_,nextAction = sess.run(train,tf.argmax(Qvals,0),
{x:prev_observation[0],x_:prev_observation[1],
th:prev_observation[2],th_:prev_observation[3],
xt:observation[0],x_t:observation[1],
tht:observation[2],th_t:observation[3]})
The problem is that I get the following error on the last line of the code above:
File "C:/Users/linna_t3vz49n/Documents/CS 486/a4/cartPole.py", line 236, in <module>
tht:observation[2],th_t:observation[3]})
File "C:\Users\linna_t3vz49n\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py", line 783, in run
compat.as_bytes(options.SerializeToString()))
AttributeError: 'dict' object has no attribute 'SerializeToString'
How can I fix this?
Answer 0 (score: 1)
Try passing train and tf.argmax(Qvals, 0) to sess.run as a single list of fetches. In your call they were given as two separate positional arguments, so the feed dict ended up in the third parameter, options, which expects a RunOptions protocol buffer; that is why TensorFlow tried to call SerializeToString() on your dict:
train_,nextAction = sess.run([train,tf.argmax(Qvals,0)],
{x:prev_observation[0],x_:prev_observation[1],
th:prev_observation[2],th_:prev_observation[3],
xt:observation[0],x_t:observation[1],
tht:observation[2],th_t:observation[3]})
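A minimal, self-contained sketch of the same pattern (the placeholders a and b are only for illustration): both fetches go into one list and the feed dict is the second argument, so a single run call returns both values:
import tensorflow as tf

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
total = a + b
product = a * b

with tf.Session() as sess:
    # one call, two fetched values; the feed dict fills both placeholders
    total_val, product_val = sess.run([total, product], {a: 2.0, b: 3.0})
    print(total_val, product_val)  # 5.0 6.0
It is also worth building the tf.argmax(Qvals, 0) op once, outside the training loop, rather than recreating it on every sess.run call, since each call to tf.argmax adds a new node to the graph.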