所以我收集了一个包含100,000个输入的数据集,其中包括游戏中的帧(元素[0])和玩家的移动(元素[1])(W,A,D)。该数据集将用于训练卷积神经网络,因此“AI”将能够在整个环境中自行导航。然后对数据集进行混洗,以便平衡它,以避免受到偏见训练的CNN。
但是我在尝试执行培训过程时遇到了这个错误,即:
---------------------------------Exception in thread Thread-3:
Traceback (most recent call last):
File "C:\Python64\lib\threading.py", line 916, in _bootstrap_inner
self.run()
File "C:\Python64\lib\threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "C:\Python64\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
data = self.retrieve_data(batch_ids)
File "C:\Python64\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
utils.slice_array(self.feed_dict[key], batch_ids)
File "C:\Python64\lib\site-packages\tflearn\utils.py", line 187, in slice_array
return X[start]
IndexError: index 39977 is out of bounds for axis 0 with size 17512
我已经在网上查了一下,大多数人说它可能是一个循环迭代错误(这可能是显而易见的),但是即使在尝试迭代直到数组结束时(使用len()函数),此错误仍会显示。
请在此处找到我的代码:
import numpy as np
from alexnetCNN import alexnet
WIDTH = 80
HEIGHT = 60
#learning rate
LR = 1e-3
EPOCHS = 8
MODEL_NAME = 'prototype-movement-{}-{}-{}-epochs.model'.format(LR, 'alexnet', EPOCHS)
model = alexnet(WIDTH, HEIGHT, LR)
train_data = np.load('E:\\Dataset\\training_data_balanced.npy')
train = train_data[:-800]
test = train_data[-800:]
#feature sets
#0 element refers to the frames recorded in the training set
#reshape gives a new shape to an array without changing its data
#--https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html
X = np.array([i[0] for i in train]).reshape(-1,WIDTH,HEIGHT,1)
#labels
#1 element refers to the movements of the player such as W, A, D
Y = [i[1] for i in train]
test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
test_y = [i[1] for i in test]
#sample testing
#-- The following method (model.fit) seems to be executing the mentioned error
model.fit({'input': X}, {'targets': Y}, n_epoch=EPOCHS, validation_set=({'input': test_x}, {'targets': test_y}),
snapshot_step=800, show_metric=True, run_id=MODEL_NAME)
# tensorboard --logdir=foo:C:/Python64/myScripts/modelLog
#when fitting is done, save the model
model.save(MODEL_NAME)
正如我在注释行中所说,model.fit()
方法似乎正在执行错误。非常感谢您的帮助!