我在这里遇到错误:
def train_model(training_data, model=False):
    """Fit a network on recorded (observation, one-hot action) samples.

    Args:
        training_data: Sequence of ``[observation, label]`` pairs. Every
            observation must have the same length as the first one, because
            that length is used for the reshape below.
        model: An already-built model to keep training. When falsy (the
            default), a new one is created with ``neural_network_model``.

    Returns:
        The model after ``fit`` has been called on it.
    """
    # Shape the observations as (samples, features, 1).
    feature_count = len(training_data[0][0])
    X = np.array([sample[0] for sample in training_data]).reshape(
        -1, feature_count, 1)
    y = [sample[1] for sample in training_data]

    # Previously the `model` argument was overwritten unconditionally, so a
    # caller-supplied model was silently thrown away.
    if not model:
        model = neural_network_model(input_size=len(X[0]))

    model.fit({'input': X}, {'targets': y}, n_epoch=4,
              snapshot_step=500, show_metric=True,
              run_id='snek_learning')
    return model
这是training_data数组的制作方法:
# Turn each recorded [observation, action] pair into an
# [observation, one-hot label] training sample.
for data in game_memory:
    action = data[1]
    # The label is 4 wide, so only actions 0-3 can be encoded. The previous
    # if/elif chain tested `== 1` twice: action 3 never received
    # [0, 0, 0, 1], and any unmatched action reused the label left over from
    # the preceding iteration.
    # NOTE(review): samples with an action outside 0-3 are skipped instead of
    # being stored with a stale label -- confirm that is the desired handling.
    if action not in (0, 1, 2, 3):
        continue
    output = [1 if slot == action else 0 for slot in range(4)]
    training_data.append([data[0], output])
我不确定我对数组形状的调整(reshape)是否正确,请帮忙看看。
编辑:
创建game_memory的代码:
# Play `initial_games` games with random actions. For every step after the
# first, store the observation seen before acting together with the action.
for _ in range(initial_games):
    score = 0
    game_memory = []
    prev_observation = []
    for _ in range(goal_steps):
        action = random.randrange(0, 5)
        observation, reward, done = Next_Frame(action)
        # The first step of a game has no previous observation to pair with.
        if len(prev_observation) > 0:
            game_memory.append([prev_observation, action])
        score = reward
        prev_observation = observation
        if done:
            break
其中observation和prev_observation始终是包含7个变量的列表。
编辑2:
在传递给train_model之前,打印出的training_data的10次迭代结果:
[False, False, False, True, False, True, False] [0, 0, 0, 1]
[False, False, False, True, False, True, False] [1, 0, 0, 0]
[False, False, False, True, False, True, False] [0, 0, 1, 0]
[False, False, False, True, False, True, False] [1, 0, 0, 0]
[False, False, False, True, False, True, False] [0, 1, 0, 0]
[False, False, False, True, False, True, False] [0, 0, 0, 1]
[False, False, False, True, False, True, False] [0, 0, 1, 0]
[False, False, False, True, False, True, False] [0, 1, 0, 0]
[False, False, False, True, False, True, False] [0, 0, 0, 1]
[False, False, False, True, False, False, True] [0, 1, 0, 0]
编辑3:
完整的错误日志:
Traceback (most recent call last):
File "c:/Users/Szymon/Documents/snake_network.py", line 273, in <module>
model = train_model(training_data)
File "c:/Users/Szymon/Documents/snake_network.py", line 266, in train_model
model.fit({'input': X}, {'targets': y}, n_epoch=4, snapshot_step=500, show_metric=True, run_id='snek_learning')
File "C:\Users\Szymon\Anaconda3\lib\site-packages\tflearn\models\dnn.py", line 216, in fit
callbacks=callbacks)
File "C:\Users\Szymon\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 339, in fit
show_metric)
File "C:\Users\Szymon\Anaconda3\lib\site-packages\tflearn\helpers\trainer.py", line 818, in _train
feed_batch)
File "C:\Users\Szymon\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 900, in run
run_metadata_ptr)
File "C:\Users\Szymon\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1104, in _run
np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
File "C:\Users\Szymon\Anaconda3\lib\site-packages\numpy\core\numeric.py", line 492, in asarray
return array(a, dtype, copy=False, order=order)
ValueError: setting an array element with a sequence.