I have a problem training a CNN. My dataset currently consists of 1000 numpy arrays, each holding an image and a (label) list.
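For reference, a dataset with that layout might be built, saved and reloaded as follows. This is a minimal sketch under stated assumptions: each image is taken to be a 120×160 single-channel uint8 array, each label a hypothetical 3-element one-hot list, and the file name `example_data.npy` is made up. Because the pairs are stored as an object array, NumPy pickles them, and NumPy 1.16.3 and later only load such files with `allow_pickle=True`.

```python
import os
import tempfile

import numpy as np

HEIGHT, WIDTH = 120, 160
N_SAMPLES = 1000

# Hypothetical data: constant-valued grayscale images and one-hot labels.
data = np.empty((N_SAMPLES, 2), dtype=object)
for k in range(N_SAMPLES):
    data[k, 0] = np.full((HEIGHT, WIDTH), k % 256, dtype=np.uint8)  # image
    label = [0, 0, 0]
    label[k % 3] = 1
    data[k, 1] = label  # label list

path = os.path.join(tempfile.mkdtemp(), "example_data.npy")
np.save(path, data)

# Object arrays are pickled, so loading needs allow_pickle=True
# on NumPy >= 1.16.3.
loaded = np.load(path, allow_pickle=True)
print(loaded.shape)        # (1000, 2)
print(loaded[0][0].shape)  # (120, 160)
print(loaded[0][1])        # [1, 0, 0]
```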
Here is the code I use to train the CNN:
import numpy as np
from alexnet import alexnet
# image resolution
WIDTH = 160
HEIGHT = 120
LR = 1e-3
EPOCHS = 8
MODEL_NAME = 'model-{}-{}-{}-epochs.model'.format(LR, 'alexnet',EPOCHS)
model = alexnet(WIDTH, HEIGHT, LR)
hm_data = 22
for i in range(EPOCHS):
    for i in range(1,hm_data+1):
        train_data = np.load('trainmodel_test1.npy'.format(i))
        train = train_data[:-100]
        test = train_data[-100:]
        X = np.array([i[0] for i in train]).reshape(-1,WIDTH,HEIGHT,1)
        Y = [i[1] for i in train]
        test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
        test_y = [i[1] for i in test]
        # ----------- error happens here: ----------------
        model.fit({'input': X}, {'targets': Y}, n_epoch=EPOCHS, validation_set=({'input': test_x}, {'targets': test_y}),
                  snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
        # ------------------------------------------------
        model.save(MODEL_NAME)
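With 1000 entries, `train_data[:-100]` and `train_data[-100:]` should yield 900 training and 100 validation entries, and `reshape(-1, WIDTH, HEIGHT, 1)` keeps those counts only if every image holds exactly WIDTH × HEIGHT values. The sketch below checks this on synthetic data, assuming single-channel images of shape (HEIGHT, WIDTH); the zero images and labels are hypothetical stand-ins for the real data. Note that `reshape` reinterprets the flattened pixels rather than transposing them, so reshaping a (120, 160) image to (160, 120) does not swap its axes.

```python
import numpy as np

WIDTH, HEIGHT = 160, 120

# Hypothetical stand-in for train_data: 1000 [image, label] pairs
# with single-channel images of shape (HEIGHT, WIDTH).
train_data = [[np.zeros((HEIGHT, WIDTH), dtype=np.uint8), [1, 0, 0]]
              for _ in range(1000)]

train = train_data[:-100]  # first 900 entries
test = train_data[-100:]   # last 100 entries

X = np.array([pair[0] for pair in train]).reshape(-1, WIDTH, HEIGHT, 1)
test_x = np.array([pair[0] for pair in test]).reshape(-1, WIDTH, HEIGHT, 1)

print(X.shape)       # (900, 160, 120, 1)
print(test_x.shape)  # (100, 160, 120, 1)
```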
When training starts, it prints:
Training samples: 2700 # -----> expected: 900 !!!
Validation samples: 300 # -----> expected: 100
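The reported counts are exactly three times the expected ones (2700 = 3 × 900, 300 = 3 × 100). That is what happens if each image actually has three values per pixel, for example RGB images of shape (HEIGHT, WIDTH, 3): with `-1` in the first position, `reshape` folds the channel axis into the sample axis, while the label list `Y` still has 900 entries. The sketch below reproduces the arithmetic on synthetic data; the three-channel image shape is an assumption, not something the question confirms.

```python
import numpy as np

WIDTH, HEIGHT = 160, 120

# Assumed: 900 color images stored as (HEIGHT, WIDTH, 3) arrays.
images = np.zeros((900, HEIGHT, WIDTH, 3), dtype=np.uint8)
labels = [[1, 0, 0]] * 900

# Asking for one channel: 900 * 120 * 160 * 3 values divided by
# 160 * 120 * 1 values per sample gives 2700 "samples".
X_wrong = images.reshape(-1, WIDTH, HEIGHT, 1)
print(X_wrong.shape[0], len(labels))  # 2700 900

# Keeping all three channels preserves the sample count.
X_rgb = images.reshape(-1, HEIGHT, WIDTH, 3)
print(X_rgb.shape[0], len(labels))    # 900 900
```

If the images are indeed three-channel, the input layer built by `alexnet(WIDTH, HEIGHT, LR)` would also have to accept three channels (or the images would need converting to grayscale before saving); how that function defines its input shape is not shown in the question.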
--
But the problem is that it multiplies the dataset size by 3 and then complains that an index exceeds the size of the whole dataset:
Exception in thread Thread-3:
Traceback (most recent call last):
  File "..\Python\Python35\Lib\threading.py", line 914, in _bootstrap_inner
    self.run()
  File "..\Python\Python35\Lib\threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "..\0.1\venv\lib\site-packages\tflearn\data_flow.py", line 187, in fill_feed_dict_queue
    data = self.retrieve_data(batch_ids)
  File "..\0.1\venv\lib\site-packages\tflearn\data_flow.py", line 222, in retrieve_data
    utils.slice_array(self.feed_dict[key], batch_ids)
  File "..\0.1\venv\lib\site-packages\tflearn\utils.py", line 187, in slice_array
    return X[start]
IndexError: index 1331 is out of bounds for axis 0 with size 900
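The traceback is consistent with an input/target length mismatch: tflearn appears to draw batch indices from the inflated number of input samples and then use the same indices to slice the 900-entry targets, so any index of 900 or above fails. Plain NumPy indexing produces the same message, and a length check before calling `model.fit` surfaces the problem earlier. This is a sketch on synthetic arrays; `check_lengths` is a hypothetical helper, not part of tflearn.

```python
import numpy as np

X = np.zeros((2700, 160, 120, 1), dtype=np.uint8)  # inflated inputs
Y = np.zeros((900, 3), dtype=np.float32)           # one row per real image

try:
    Y[1331]
except IndexError as err:
    print(err)  # index 1331 is out of bounds for axis 0 with size 900


def check_lengths(inputs, targets):
    """Hypothetical helper: fail fast if inputs and targets disagree."""
    if len(inputs) != len(targets):
        raise ValueError(
            f"{len(inputs)} inputs but {len(targets)} targets; "
            "check the reshape and the channel count"
        )


try:
    check_lengths(X, Y)
except ValueError as err:
    print(err)  # 2700 inputs but 900 targets; check the reshape and the channel count
```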
Any hints?
~ Code adapted from sentdex