ValueError:尝试将“输入”转换为张量,但失败

时间:2019-03-14 14:52:27

标签: python tensorflow

我正在尝试在Tensorflow中运行强化学习神经网络。该网络应该模拟从车站到车站的公交车。距离数组是所有停靠点之间的距离。

r是每一站的奖励,为1 /距离,以便距离越短得分越高。我将停止点与自身之间的距离更改为100,以使得分最低。

我正在尝试将奖励列表馈入网络,但无法使其正常工作。

我不断收到此错误消息:

ValueError: Tried to convert 'input' to a tensor and failed. Error: Argument must be a dense tensor: [array([[1.4126427]], dtype=float32)] - got shape [1, 1, 1], but wanted [1].

谢谢,我刚开始使用Tensorflow。

  import numpy as np
    distance =  np.array([[100.0,                         1.2666252330640149,            0.25730267246737104,           1.036047234307023,             1.3070229956494719,            1.4220012430080795,            0.6072094468614045,            0.5419515226848974,            1.8446239900559354,            0.8707271597265382,            0.8707271597265382,            1.5438160348042262,            0.5419515226848974,            1.4698570540708515,            1.3921690490988192,            0.8048477315102548,            1.4188937228091982,            1.755127408328154,             0.1752641392169049,            1.7955251709136109,            2.2591671845867,               0.3380981976382847,            1.6967060285891857,            0.5487880671224363,            0.7296457426973275],            
                      [0.9254195152268491,            100.0,                         1.603480422622747,             0.13238036047234306,           0.4033561218147918,            0.5183343691733996,            1.8147917961466749,            0.3834679925419515,            0.9409571162212554,            0.8719701678060907,            0.6668738346799254,            0.6401491609695463,            0.22871348663766314,           1.8116842759477938,            2.219390926041019,             1.0111870727159726,            1.5363579863269112,            1.3374766935985085,            0.4916096954630205,            2.1330018645121194,            1.5587321317588565,            0.32939714108141704,           1.2591671845866996,            0.7383467992541951,            0.9192044748290864],           
                      [0.25730267246737104,           1.944686140459913,             100.0,                         1.7141081417029211,            1.9850839030453697,            2.206960845245494,             0.3499067743940335,            1.2200124300807955,            2.2343070229956496,            0.6134244872591672,            0.8178993163455562,            2.144810441267868,             0.7998756991920447,            1.2044748290863891,            0.5239279055313859,            0.5475450590428838,            1.1535114978247358,            1.4897451833436917,            0.4331883157240522,            1.5301429459291487,            1.9937849596022374,            0.5960223741454319,            1.4313238036047233,            1.2268489745183344,            1.4077066500932256],            
                      [1.1553760099440646,            0.22995649471721566,           1.833436917339963,             100.0,                         0.2703542573026725,            0.3853325046612803,            1.7756370416407707,            0.6134244872591672,            0.8079552517091362,            0.7389683032939715,            0.5338719701678061,            0.507147296457427,             0.29459291485394656,           1.6786824114356744,            1.600994406463642,             0.8781852082038534,            1.4033561218147919,            1.2044748290863891,            0.7215661901802362,            2.0,                           1.4257302672467371,            0.5593536357986327,            1.1261653200745805,            0.9683032939714109,            1.149160969546302],             
                      [1.3927905531385956,            0.46737103791174645,           2.0708514605344934,            0.6065879428216283,            100.0,                         0.20758234928527036,           1.504661280298322,             0.8508390304536979,            0.5369794903666874,            0.4679925419515227,            0.26289620882535736,           0.23617153511497826,           0.6960845245494096,            1.4077066500932256,            1.3300186451211933,            0.6072094468614045,            1.132380360472343,             0.9334990677439403,            0.9589807333747669,            1.7290242386575514,            1.1547545059042885,            0.7967681789931634,            0.8551895587321316,            1.2057178371659416,            1.3865755127408328],            
                      [1.1845866998135488,            0.2591671845866998,            1.8626476072094469,            0.3915475450590429,            0.2262274704785581,            100.0           ,              1.7315102548166563,            0.6426351771286514,            0.7638284648850218,            0.5059042883778744,            0.3008079552517091,            0.4630205096333126,            0.48788067122436296,           1.63455562461156,              1.5568676196395277,            0.6451211932877564,            1.3592293349906774,            1.1603480422622747,            0.7507768800497203,            1.9558732131758856,            1.3816034804226227,            0.5885643256681168,            1.0820385332504663,            0.997513983840895,             1.1783716594157863],            
                      [0.6072094468614045,            1.6848974518334368,            0.3499067743940335,            1.8601615910503417,            1.6861404599129894,            1.425108763206961,             100.0,                         1.4313238036047233,            1.4524549409571164,            0.4126786824114357,            0.6171535114978247,            1.362958359229335,             1.7688004972032318,            0.42262274704785585,           0.1740211311373524,            0.3467992541951523,            0.3716594157862026,            0.7078931013051585,            0.7830950901180858,            1.3399627097576134,            0.9291485394655066,            0.9459291485394655,            0.6494717215661902,            1.4381603480422622,            1.6190180236171534],            
                      [0.5419515226848974,            0.7246737103791174,            1.2200124300807955,            0.49409571162212557,           0.7650714729645743,            0.8800497203231821,            1.4313238036047233,            100.0,                         1.302672467371038,             1.2336855189558733,            1.0285891858297078,            1.0018645121193288,            0.3374766935985084,            1.9136109384711002,            1.8359229334990679,            1.6289620882535736,            1.8626476072094469,            2.1988812927284025,            0.4586699813548788,            2.2392790553138595,            2.7029210689869485,            0.2747047855811063,            2.1404599129894346,            0.3548788067122436,            0.5357364822871349],            
                      [2.014294592914854,             1.052827843380982,             1.7489123679303915,            1.038533250466128,             0.8645121193287757,            0.7930391547545059,            0.9670602858918583,            1.4362958359229334,            100.0,                         0.7874456183965195,            0.8483530142945929,            0.5413300186451212,            1.2815413300186451,            0.8701056556867619,            0.7924176507147296,            0.6476072094468615,            0.5947793660658794,            0.39589807333747673,           1.5444375388440024,            1.7290242386575514,            0.6171535114978247,            1.382224984462399,             0.3175885643256681,            1.7911746426351771,            1.9720323182100683],            
                      [0.8707271597265382,            0.6793039154754505,            0.6134244872591672,            0.8116842759477937,            0.49844623990055936,           0.4195152268489746,            0.4126786824114357,            1.062771908017402,             0.9453076444996892,            100.0,                         0.20447482908638906,           0.6444996892479801,            0.9080174021131138,            1.002486016159105,             0.5866998135487881,            0.13921690490988192,           0.7271597265382225,            0.5282784338098198,            0.7252952144188938,            1.603480422622747,             0.7495338719701677,            0.6550652579241766,            0.4698570540708515,            1.4176507147296458,            1.598508390304537],             
                      [1.399627097576134,             0.4742075823492853,            0.8178993163455562,            0.6065879428216283,            0.29334990677439404,           0.21441889372280917,           0.6171535114978247,            0.8576755748912367,            0.7402113113735239,            0.20447482908638906,           100.0,                         0.43940335612181475,           0.7029210689869484,            1.6109384711000623,            1.53325046612803,              0.343691733996271,             1.3356121814791797,            1.1367308887507768,            0.9658172778123059,            1.9322560596643878,            1.357986326911125,             0.8036047234307022,            1.0584213797389683,            1.2125543816034805,            1.3934120571783717],            
                      [1.6476072094468612,            0.7221876942200124,            2.049720323182101,             0.8545680546923555,            0.5413300186451212,            0.46239900559353636,           1.2678682411435676,            1.105655686761964,             0.3001864512119329,            0.7228091982597887,            0.5177128651336234,            100.0,                         0.9509011808576756,            1.170913610938471,             1.0932256059664387,            0.9484151646985706,            0.8955873213175886,            0.6967060285891858,            1.213797389683033,             1.4922311995027966,            0.9179614667495339,            1.0515848353014294,            0.6183965195773773,            1.4605344934742077,            1.641392169049099],             
                      [0.7998756991920447,            0.4754505904288378,            1.477936606587943,             0.24487259167184589,           0.5158483530142945,            0.6308266003729024,            1.6892479801118707,            0.2579241765071473,            1.0534493474207582,            0.9844623990055936,            0.7793660658794282,            0.7526413921690491,            100.0,                         2.1715351149782474,            2.093847110006215,             1.1236793039154755,            1.648850217526414,             1.4499689247980112,            0.3660658794282163,            2.4972032318210067,            1.6712243629583592,            0.20385332504661283,           1.3716594157862025,            0.612802983219391,             0.7936606587942822],            
                      [1.4698570540708515,            1.5879428216283407,            1.2044748290863891,            1.7632069608452456,            1.5891858297078931,            1.3281541330018645,            0.42262274704785585,           1.9136109384711002,            1.3555003107520198,            1.002486016159105,             1.3834679925419515,            1.2660037290242385,            1.8166563082660037,            100.0,                         0.2479801118707272,            0.8626476072094468,            0.2747047855811063,            0.6109384711000622,            1.810441267868241,             0.2815413300186451,            0.745183343691734,             1.9173399627097576,            0.5525170913610938,            1.9204474829086389,            2.10130515848353],              
                      [1.3921690490988192,            1.5102548166563083,            0.5239279055313859,            1.6855189558732133,            1.5114978247358608,            1.2504661280298321,            0.1740211311373524,            1.8359229334990679,            1.2778123057799877,            0.5866998135487881,            1.3057799875699192,            1.1883157240522062,            1.7389683032939716,            0.2479801118707272,            100.0,                         0.7849596022374145,            0.19701678060907396,           0.5332504661280298,            1.7327532628962088,            1.1068986948415165,            0.7545059042883778,            1.8396519577377253,            0.4748290863890615,            1.8427594779366065,            2.0236171535114975],            
                      [0.8048477315102548,            0.8185208203853325,            0.5475450590428838,            0.9509011808576756,            0.6376631448104413,            0.5587321317588565,            0.3467992541951523,            1.201988812927284,             1.1330018645121194,            0.13921690490988192,           0.343691733996271,             1.043505282784338,             1.0472343070229957,            0.8626476072094468,            0.5208203853325046,            100.0,                         0.5873213175885643,            0.3884400248601616,            0.8645121193287757,            1.7215661901802362,            0.6096954630205096,            0.7942821628340584,            0.33001864512119333,           1.6357986326911125,            1.8166563082660037],            
                      [1.4188937228091982,            1.312616532007458,             1.1535114978247358,            1.487880671224363,             1.3138595400870106,            1.052827843380982,             0.3716594157862026,            1.8626476072094469,            1.0801740211311373,            0.7271597265382225,            1.108141702921069,             0.9906774394033562,            1.541330018645121,             0.2747047855811063,            0.19701678060907396,           0.5873213175885643,            100.0,                         0.33561218147917965,           1.759477936606588,             1.1336233685518957,            0.5568676196395277,            1.642013673088875,             0.2771908017402113,         1.8694841516469858,            2.0503418272218767],            
                      [1.755127408328154,             1.1137352392790554,            1.4897451833436917,            1.2889993784959601,            1.1149782473586078,            0.8539465506525793,            0.7078931013051585,            1.4972032318210067,            0.8812927284027345,            0.5282784338098198,            0.9092604101926663,            0.7917961466749535,            1.3424487259167186,            0.6109384711000622,            0.5332504661280298,            0.3884400248601616,            0.33561218147917965,           100.0,                         2.0957116221255436,            0.6842759477936606,            0.22063393412057178,           1.4431323803604723,            0.07830950901180858,           2.2057178371659414,            2.3865755127408326],            
                      [0.1752641392169049,            0.8421379738968303,            0.4331883157240522,            0.6115599751398384,            0.8825357364822871,            0.997513983840895,             0.7830950901180858,            0.4586699813548788,            1.420136730888751,             0.7252952144188938,            0.6948415164698571,            1.1193287756370416,            0.3660658794282163,            1.810441267868241,             1.7327532628962088,            0.8645121193287757,            1.759477936606588,             2.0957116221255436,            100.0,                         2.1361093847110006,            2.599751398384089,             0.16221255438160348,           2.0372902423865757,            0.46550652579241764,           0.6463642013673089],            
                      [1.7955251709136109,            1.9210689869484152,            1.5301429459291487,            2.053449347420758,             1.7402113113735238,            1.6612802983219392,            1.2815413300186451,            2.2392790553138595,            1.7719080174021131,            1.9216904909881916,            1.716594157862026,             1.586699813548788,             2.1497824735860784,            0.2815413300186451,            1.1068986948415165,            1.5860783095090119,            1.1336233685518957,            0.6842759477936606,            2.1361093847110006,            100.0,                         0.4630205096333126,            2.2504661280298324,            1.2560596643878186,            2.2461155997513984,            2.4269732753262896],            
                      [2.2591671845867,               1.7339962709757615,            1.9937849596022374,            1.8663766314481045,            1.5531385954008703,            1.4742075823492853,            0.9291485394655066,            2.117464263517713,             1.1025481665630827,            0.7495338719701677,            1.5295214418893721,            1.399627097576134,             1.9627097576134245,            0.745183343691734,             0.7545059042883778,            0.6096954630205096,            0.5568676196395277,            0.22063393412057178,           2.2256059664387817,            0.4630205096333126,            100.0,                         2.063393412057178,             0.2995649471721566,            2.7097576134244874,            2.8906152889993786],            
                      [0.3380981976382847,            0.6799254195152269,            0.5960223741454319,            0.4493474207582349,            0.7203231821006837,            0.8353014294592915,            1.7899316345556247,            0.35860783095090115,           1.2579241765071474,            0.6550652579241766,            0.6246115599751397,            0.9571162212554382,            0.20385332504661283,           2.2722187694220013,            2.194530764449969,             0.7942821628340584,            2.2212554381603478,            1.6544437538844001,            0.16221255438160348,           2.5978868862647606,            1.875699192044748,             100.0,                         1.5761342448725917,            0.7134866376631448,            0.8943443132380361],            
                      [1.6967060285891857,            1.0354257302672467,            1.4313238036047233,            1.2106898694841517,            1.0366687383467992,            0.7756370416407706,            0.6494717215661902,            1.4188937228091982,            0.8029832193909261,            0.4698570540708515,            0.8309509011808577,            0.7134866376631448,            1.2641392169049097,            0.5525170913610938,            0.4748290863890615,            0.33001864512119333,           0.2771908017402113,            0.07830950901180858,           1.5270354257302672,            1.4114356743318832,            0.2995649471721566,            1.3648228713486639,            100.0,                         2.1472964574269735,            2.3281541330018647],            
                      [0.5487880671224363,            1.0795525170913611,            1.2268489745183344,            0.8489745183343692,            1.119950279676818,             1.2349285270354258,            1.4381603480422622,            0.3548788067122436,            1.6575512740832814,            1.5885643256681168,            1.3834679925419515,            1.3567433188315723,            0.6923555003107521,            1.9204474829086389,            1.8427594779366065,            1.6357986326911125,            1.8694841516469858,            2.2057178371659414,            0.46550652579241764,           2.2461155997513984,            2.7097576134244874,            0.6295835922933498,            2.1472964574269735,            100.0,                         0.1808576755748912],            
                      [0.7756370416407706,            1.3064014916096953,            1.4536979490366688,            1.0758234928527035,            1.3467992541951521,            1.46177750155376,              1.6650093225605966,            0.581727781230578,             1.884400248601616,             1.8154133001864512,            1.610316967060286,             1.583592293349907,             0.9192044748290864,            2.1472964574269735,            2.069608452454941,             1.8626476072094469,            2.0963331261653204,            2.4325668116842762,            0.6923555003107521,            2.472964574269733,             2.9366065879428214,            0.8564325668116842,            2.3741454319453075,            0.22684897451833436,           100.0]]).astype("float32")                           

    r = 1/distance  


    import tensorflow as tf
    import random
    BATCH_SIZE = 30
    num_stops = len(r)



    observations = tf.placeholder('float32', shape=[None, num_stops]) # Current game states : r[stop], r[next_stop], r[third_stop]
    actions = tf.placeholder('int32',shape=[None])  # 0 - num-stops for actions taken
    rewards = tf.placeholder('float32',shape=[None])  # +1, -1 with discounts

# Model
    Y = tf.layers.dense(observations, 200, activation=tf.nn.relu)
    Ylogits = tf.layers.dense(Y, num_stops)

    # sample an action from predicted probabilities
    sample_op = tf.random.categorical(logits=Ylogits, num_samples=1)


    # loss
    cross_entropies = tf.losses.softmax_cross_entropy(onehot_labels=tf.one_hot(actions,num_stops), logits=Ylogits)

    loss = tf.reduce_sum(rewards * cross_entropies)

    # training operation
    optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001, decay=.99)
    train_op = optimizer.minimize(loss)


# Run model

    visited_stops = []
    steps = 0

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        # Start at a random stop, initialize done to false
        current_stop = random.randint(0, len(r) - 1)
        done = False

        # reset everything    
        while not done: # play a game in x steps   

            observations_list = []
            actions_list = []
            rewards_list = []

            # List all stops and their scores
            observation = r[current_stop]

            # Add the stop to a list of non-visited stops if it isn't
            # already there
            if current_stop not in visited_stops:
                visited_stops.append(current_stop)

            # decide where to go
            action = sess.run(sample_op, feed_dict={observations: [observation]})

            # play it, output next state, reward if we got a point, and whether the game is over
            #game_state, reward, done, info = pong_sim.step(action)
            new_stop = int(action)


            reward = r[current_stop][action]

            if len(visited_stops) == num_stops:
                done = True

            if steps >= BATCH_SIZE:
                done = True

            steps += 1

            observations_list.append(observation)
            actions_list.append(action)
            rewards_list.append(reward)

            print(rewards_list)




            #rewards_list = np.reshape(rewards, [-1, 25])
            current_stop = new_stop

        #processed_rewards = discount_rewards(rewards, args.gamma)
        #processed_rewards = normalize_rewards(rewards, args.gamma)

        tf.squeeze(rewards_list, axis=1) 

        sess.run(train_op, feed_dict={observations: [observations_list],
                                 actions: [actions_list],
                                 rewards: [rewards_list]})

0 个答案:

没有答案