我正在尝试在Tensorflow中运行强化学习神经网络。该网络应该模拟从车站到车站的公交车。距离数组是所有停靠点之间的距离。
r 是每一站的奖励，取值为 1/距离，这样距离越短得分越高。我将停靠点与自身之间的距离设为 100，使其得分最低。
我正在尝试将奖励列表馈入网络,但无法使其正常工作。
我不断收到此错误消息:
ValueError: Tried to convert 'input' to a tensor and failed. Error: Argument must be a dense tensor: [array([[1.4126427]], dtype=float32)] - got shape [1, 1, 1], but wanted [1].
谢谢,我刚开始使用Tensorflow。
import numpy as np
distance = np.array([[100.0, 1.2666252330640149, 0.25730267246737104, 1.036047234307023, 1.3070229956494719, 1.4220012430080795, 0.6072094468614045, 0.5419515226848974, 1.8446239900559354, 0.8707271597265382, 0.8707271597265382, 1.5438160348042262, 0.5419515226848974, 1.4698570540708515, 1.3921690490988192, 0.8048477315102548, 1.4188937228091982, 1.755127408328154, 0.1752641392169049, 1.7955251709136109, 2.2591671845867, 0.3380981976382847, 1.6967060285891857, 0.5487880671224363, 0.7296457426973275],
[0.9254195152268491, 100.0, 1.603480422622747, 0.13238036047234306, 0.4033561218147918, 0.5183343691733996, 1.8147917961466749, 0.3834679925419515, 0.9409571162212554, 0.8719701678060907, 0.6668738346799254, 0.6401491609695463, 0.22871348663766314, 1.8116842759477938, 2.219390926041019, 1.0111870727159726, 1.5363579863269112, 1.3374766935985085, 0.4916096954630205, 2.1330018645121194, 1.5587321317588565, 0.32939714108141704, 1.2591671845866996, 0.7383467992541951, 0.9192044748290864],
[0.25730267246737104, 1.944686140459913, 100.0, 1.7141081417029211, 1.9850839030453697, 2.206960845245494, 0.3499067743940335, 1.2200124300807955, 2.2343070229956496, 0.6134244872591672, 0.8178993163455562, 2.144810441267868, 0.7998756991920447, 1.2044748290863891, 0.5239279055313859, 0.5475450590428838, 1.1535114978247358, 1.4897451833436917, 0.4331883157240522, 1.5301429459291487, 1.9937849596022374, 0.5960223741454319, 1.4313238036047233, 1.2268489745183344, 1.4077066500932256],
[1.1553760099440646, 0.22995649471721566, 1.833436917339963, 100.0, 0.2703542573026725, 0.3853325046612803, 1.7756370416407707, 0.6134244872591672, 0.8079552517091362, 0.7389683032939715, 0.5338719701678061, 0.507147296457427, 0.29459291485394656, 1.6786824114356744, 1.600994406463642, 0.8781852082038534, 1.4033561218147919, 1.2044748290863891, 0.7215661901802362, 2.0, 1.4257302672467371, 0.5593536357986327, 1.1261653200745805, 0.9683032939714109, 1.149160969546302],
[1.3927905531385956, 0.46737103791174645, 2.0708514605344934, 0.6065879428216283, 100.0, 0.20758234928527036, 1.504661280298322, 0.8508390304536979, 0.5369794903666874, 0.4679925419515227, 0.26289620882535736, 0.23617153511497826, 0.6960845245494096, 1.4077066500932256, 1.3300186451211933, 0.6072094468614045, 1.132380360472343, 0.9334990677439403, 0.9589807333747669, 1.7290242386575514, 1.1547545059042885, 0.7967681789931634, 0.8551895587321316, 1.2057178371659416, 1.3865755127408328],
[1.1845866998135488, 0.2591671845866998, 1.8626476072094469, 0.3915475450590429, 0.2262274704785581, 100.0 , 1.7315102548166563, 0.6426351771286514, 0.7638284648850218, 0.5059042883778744, 0.3008079552517091, 0.4630205096333126, 0.48788067122436296, 1.63455562461156, 1.5568676196395277, 0.6451211932877564, 1.3592293349906774, 1.1603480422622747, 0.7507768800497203, 1.9558732131758856, 1.3816034804226227, 0.5885643256681168, 1.0820385332504663, 0.997513983840895, 1.1783716594157863],
[0.6072094468614045, 1.6848974518334368, 0.3499067743940335, 1.8601615910503417, 1.6861404599129894, 1.425108763206961, 100.0, 1.4313238036047233, 1.4524549409571164, 0.4126786824114357, 0.6171535114978247, 1.362958359229335, 1.7688004972032318, 0.42262274704785585, 0.1740211311373524, 0.3467992541951523, 0.3716594157862026, 0.7078931013051585, 0.7830950901180858, 1.3399627097576134, 0.9291485394655066, 0.9459291485394655, 0.6494717215661902, 1.4381603480422622, 1.6190180236171534],
[0.5419515226848974, 0.7246737103791174, 1.2200124300807955, 0.49409571162212557, 0.7650714729645743, 0.8800497203231821, 1.4313238036047233, 100.0, 1.302672467371038, 1.2336855189558733, 1.0285891858297078, 1.0018645121193288, 0.3374766935985084, 1.9136109384711002, 1.8359229334990679, 1.6289620882535736, 1.8626476072094469, 2.1988812927284025, 0.4586699813548788, 2.2392790553138595, 2.7029210689869485, 0.2747047855811063, 2.1404599129894346, 0.3548788067122436, 0.5357364822871349],
[2.014294592914854, 1.052827843380982, 1.7489123679303915, 1.038533250466128, 0.8645121193287757, 0.7930391547545059, 0.9670602858918583, 1.4362958359229334, 100.0, 0.7874456183965195, 0.8483530142945929, 0.5413300186451212, 1.2815413300186451, 0.8701056556867619, 0.7924176507147296, 0.6476072094468615, 0.5947793660658794, 0.39589807333747673, 1.5444375388440024, 1.7290242386575514, 0.6171535114978247, 1.382224984462399, 0.3175885643256681, 1.7911746426351771, 1.9720323182100683],
[0.8707271597265382, 0.6793039154754505, 0.6134244872591672, 0.8116842759477937, 0.49844623990055936, 0.4195152268489746, 0.4126786824114357, 1.062771908017402, 0.9453076444996892, 100.0, 0.20447482908638906, 0.6444996892479801, 0.9080174021131138, 1.002486016159105, 0.5866998135487881, 0.13921690490988192, 0.7271597265382225, 0.5282784338098198, 0.7252952144188938, 1.603480422622747, 0.7495338719701677, 0.6550652579241766, 0.4698570540708515, 1.4176507147296458, 1.598508390304537],
[1.399627097576134, 0.4742075823492853, 0.8178993163455562, 0.6065879428216283, 0.29334990677439404, 0.21441889372280917, 0.6171535114978247, 0.8576755748912367, 0.7402113113735239, 0.20447482908638906, 100.0, 0.43940335612181475, 0.7029210689869484, 1.6109384711000623, 1.53325046612803, 0.343691733996271, 1.3356121814791797, 1.1367308887507768, 0.9658172778123059, 1.9322560596643878, 1.357986326911125, 0.8036047234307022, 1.0584213797389683, 1.2125543816034805, 1.3934120571783717],
[1.6476072094468612, 0.7221876942200124, 2.049720323182101, 0.8545680546923555, 0.5413300186451212, 0.46239900559353636, 1.2678682411435676, 1.105655686761964, 0.3001864512119329, 0.7228091982597887, 0.5177128651336234, 100.0, 0.9509011808576756, 1.170913610938471, 1.0932256059664387, 0.9484151646985706, 0.8955873213175886, 0.6967060285891858, 1.213797389683033, 1.4922311995027966, 0.9179614667495339, 1.0515848353014294, 0.6183965195773773, 1.4605344934742077, 1.641392169049099],
[0.7998756991920447, 0.4754505904288378, 1.477936606587943, 0.24487259167184589, 0.5158483530142945, 0.6308266003729024, 1.6892479801118707, 0.2579241765071473, 1.0534493474207582, 0.9844623990055936, 0.7793660658794282, 0.7526413921690491, 100.0, 2.1715351149782474, 2.093847110006215, 1.1236793039154755, 1.648850217526414, 1.4499689247980112, 0.3660658794282163, 2.4972032318210067, 1.6712243629583592, 0.20385332504661283, 1.3716594157862025, 0.612802983219391, 0.7936606587942822],
[1.4698570540708515, 1.5879428216283407, 1.2044748290863891, 1.7632069608452456, 1.5891858297078931, 1.3281541330018645, 0.42262274704785585, 1.9136109384711002, 1.3555003107520198, 1.002486016159105, 1.3834679925419515, 1.2660037290242385, 1.8166563082660037, 100.0, 0.2479801118707272, 0.8626476072094468, 0.2747047855811063, 0.6109384711000622, 1.810441267868241, 0.2815413300186451, 0.745183343691734, 1.9173399627097576, 0.5525170913610938, 1.9204474829086389, 2.10130515848353],
[1.3921690490988192, 1.5102548166563083, 0.5239279055313859, 1.6855189558732133, 1.5114978247358608, 1.2504661280298321, 0.1740211311373524, 1.8359229334990679, 1.2778123057799877, 0.5866998135487881, 1.3057799875699192, 1.1883157240522062, 1.7389683032939716, 0.2479801118707272, 100.0, 0.7849596022374145, 0.19701678060907396, 0.5332504661280298, 1.7327532628962088, 1.1068986948415165, 0.7545059042883778, 1.8396519577377253, 0.4748290863890615, 1.8427594779366065, 2.0236171535114975],
[0.8048477315102548, 0.8185208203853325, 0.5475450590428838, 0.9509011808576756, 0.6376631448104413, 0.5587321317588565, 0.3467992541951523, 1.201988812927284, 1.1330018645121194, 0.13921690490988192, 0.343691733996271, 1.043505282784338, 1.0472343070229957, 0.8626476072094468, 0.5208203853325046, 100.0, 0.5873213175885643, 0.3884400248601616, 0.8645121193287757, 1.7215661901802362, 0.6096954630205096, 0.7942821628340584, 0.33001864512119333, 1.6357986326911125, 1.8166563082660037],
[1.4188937228091982, 1.312616532007458, 1.1535114978247358, 1.487880671224363, 1.3138595400870106, 1.052827843380982, 0.3716594157862026, 1.8626476072094469, 1.0801740211311373, 0.7271597265382225, 1.108141702921069, 0.9906774394033562, 1.541330018645121, 0.2747047855811063, 0.19701678060907396, 0.5873213175885643, 100.0, 0.33561218147917965, 1.759477936606588, 1.1336233685518957, 0.5568676196395277, 1.642013673088875, 0.2771908017402113, 1.8694841516469858, 2.0503418272218767],
[1.755127408328154, 1.1137352392790554, 1.4897451833436917, 1.2889993784959601, 1.1149782473586078, 0.8539465506525793, 0.7078931013051585, 1.4972032318210067, 0.8812927284027345, 0.5282784338098198, 0.9092604101926663, 0.7917961466749535, 1.3424487259167186, 0.6109384711000622, 0.5332504661280298, 0.3884400248601616, 0.33561218147917965, 100.0, 2.0957116221255436, 0.6842759477936606, 0.22063393412057178, 1.4431323803604723, 0.07830950901180858, 2.2057178371659414, 2.3865755127408326],
[0.1752641392169049, 0.8421379738968303, 0.4331883157240522, 0.6115599751398384, 0.8825357364822871, 0.997513983840895, 0.7830950901180858, 0.4586699813548788, 1.420136730888751, 0.7252952144188938, 0.6948415164698571, 1.1193287756370416, 0.3660658794282163, 1.810441267868241, 1.7327532628962088, 0.8645121193287757, 1.759477936606588, 2.0957116221255436, 100.0, 2.1361093847110006, 2.599751398384089, 0.16221255438160348, 2.0372902423865757, 0.46550652579241764, 0.6463642013673089],
[1.7955251709136109, 1.9210689869484152, 1.5301429459291487, 2.053449347420758, 1.7402113113735238, 1.6612802983219392, 1.2815413300186451, 2.2392790553138595, 1.7719080174021131, 1.9216904909881916, 1.716594157862026, 1.586699813548788, 2.1497824735860784, 0.2815413300186451, 1.1068986948415165, 1.5860783095090119, 1.1336233685518957, 0.6842759477936606, 2.1361093847110006, 100.0, 0.4630205096333126, 2.2504661280298324, 1.2560596643878186, 2.2461155997513984, 2.4269732753262896],
[2.2591671845867, 1.7339962709757615, 1.9937849596022374, 1.8663766314481045, 1.5531385954008703, 1.4742075823492853, 0.9291485394655066, 2.117464263517713, 1.1025481665630827, 0.7495338719701677, 1.5295214418893721, 1.399627097576134, 1.9627097576134245, 0.745183343691734, 0.7545059042883778, 0.6096954630205096, 0.5568676196395277, 0.22063393412057178, 2.2256059664387817, 0.4630205096333126, 100.0, 2.063393412057178, 0.2995649471721566, 2.7097576134244874, 2.8906152889993786],
[0.3380981976382847, 0.6799254195152269, 0.5960223741454319, 0.4493474207582349, 0.7203231821006837, 0.8353014294592915, 1.7899316345556247, 0.35860783095090115, 1.2579241765071474, 0.6550652579241766, 0.6246115599751397, 0.9571162212554382, 0.20385332504661283, 2.2722187694220013, 2.194530764449969, 0.7942821628340584, 2.2212554381603478, 1.6544437538844001, 0.16221255438160348, 2.5978868862647606, 1.875699192044748, 100.0, 1.5761342448725917, 0.7134866376631448, 0.8943443132380361],
[1.6967060285891857, 1.0354257302672467, 1.4313238036047233, 1.2106898694841517, 1.0366687383467992, 0.7756370416407706, 0.6494717215661902, 1.4188937228091982, 0.8029832193909261, 0.4698570540708515, 0.8309509011808577, 0.7134866376631448, 1.2641392169049097, 0.5525170913610938, 0.4748290863890615, 0.33001864512119333, 0.2771908017402113, 0.07830950901180858, 1.5270354257302672, 1.4114356743318832, 0.2995649471721566, 1.3648228713486639, 100.0, 2.1472964574269735, 2.3281541330018647],
[0.5487880671224363, 1.0795525170913611, 1.2268489745183344, 0.8489745183343692, 1.119950279676818, 1.2349285270354258, 1.4381603480422622, 0.3548788067122436, 1.6575512740832814, 1.5885643256681168, 1.3834679925419515, 1.3567433188315723, 0.6923555003107521, 1.9204474829086389, 1.8427594779366065, 1.6357986326911125, 1.8694841516469858, 2.2057178371659414, 0.46550652579241764, 2.2461155997513984, 2.7097576134244874, 0.6295835922933498, 2.1472964574269735, 100.0, 0.1808576755748912],
[0.7756370416407706, 1.3064014916096953, 1.4536979490366688, 1.0758234928527035, 1.3467992541951521, 1.46177750155376, 1.6650093225605966, 0.581727781230578, 1.884400248601616, 1.8154133001864512, 1.610316967060286, 1.583592293349907, 0.9192044748290864, 2.1472964574269735, 2.069608452454941, 1.8626476072094469, 2.0963331261653204, 2.4325668116842762, 0.6923555003107521, 2.472964574269733, 2.9366065879428214, 0.8564325668116842, 2.3741454319453075, 0.22684897451833436, 100.0]]).astype("float32")
# Reward matrix: shorter distance -> larger reward; the diagonal was set to 100
# in `distance`, so a stop's reward for staying put is the smallest (0.01).
r = 1 / distance

import tensorflow as tf
import random

BATCH_SIZE = 30
num_stops = len(r)

# Placeholders for one batch of experience.
observations = tf.placeholder('float32', shape=[None, num_stops])  # reward row of the current stop per step
actions = tf.placeholder('int32', shape=[None])                    # stop index chosen at each step
rewards = tf.placeholder('float32', shape=[None])                  # reward received at each step

# Model: one hidden layer, then logits over the possible next stops.
Y = tf.layers.dense(observations, 200, activation=tf.nn.relu)
Ylogits = tf.layers.dense(Y, num_stops)

# Sample an action from the predicted (unnormalized) log-probabilities.
# Output shape is [batch_size, num_samples] == [1, 1] when fed one observation.
sample_op = tf.random.categorical(logits=Ylogits, num_samples=1)

# Policy-gradient loss: each step's cross entropy must be weighted by that
# step's reward. tf.losses.softmax_cross_entropy reduces to a single scalar,
# which would multiply every reward by the same value — use the per-sample
# op so `cross_entropies` has shape [batch] matching `rewards`.
cross_entropies = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(actions, num_stops), logits=Ylogits)
loss = tf.reduce_sum(rewards * cross_entropies)

# Training operation.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=.99)
train_op = optimizer.minimize(loss)
# Run model: play one episode, collecting (observation, action, reward) per
# step, then run a single training update on the whole batch.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Batch buffers must be created ONCE, before the loop — the original code
    # re-initialized them every iteration, so the batch never grew past 1.
    observations_list = []
    actions_list = []
    rewards_list = []
    visited_stops = []
    steps = 0

    # Start at a random stop.
    current_stop = random.randint(0, num_stops - 1)
    done = False

    while not done:  # play a game in at most BATCH_SIZE steps
        # Current state: this stop's reward row over all stops.
        observation = r[current_stop]
        if current_stop not in visited_stops:
            visited_stops.append(current_stop)

        # Decide where to go. sample_op yields a [1, 1] array; reduce it to a
        # plain int immediately so it never pollutes the batch lists. Feeding
        # the raw array was the cause of the reported
        # "got shape [1, 1, 1], but wanted [1]" ValueError.
        action = sess.run(sample_op, feed_dict={observations: [observation]})
        new_stop = int(action)
        reward = float(r[current_stop][new_stop])

        # Episode ends when every stop was visited or the batch is full.
        if len(visited_stops) == num_stops or steps >= BATCH_SIZE:
            done = True
        steps += 1

        observations_list.append(observation)  # shape (num_stops,)
        actions_list.append(new_stop)          # scalar int
        rewards_list.append(reward)            # scalar float
        current_stop = new_stop

    # Feed flat, correctly-shaped batches — no extra [..] wrapping:
    #   observations: [steps, num_stops], actions: [steps], rewards: [steps]
    sess.run(train_op, feed_dict={observations: np.array(observations_list),
                                  actions: np.array(actions_list),
                                  rewards: np.array(rewards_list)})