我知道这个问题已经问了很多,但是我仍然无法解决,我使用自己的图像数据集,并且在photoshop中裁剪了其中的一些图像,“如果它可以解决问题”然后每次我尝试训练网络时,都会出现此错误。这是我的代码:
def create_train_data():
training_data = []
for img in tqdm(os.listdir(TRAIN_DIR)):
path = os.path.join(TRAIN_DIR, img)
img_data = cv2.imread(path, 0)
try:
img_data = cv2.resize(img_data, (IMG_SIZE, IMG_SIZE))
except:
exc_type, exc_obj, tb = sys.exc_info()
f = tb.tb_frame
lineno = tb.tb_lineno
filename = img
linecache.checkcache(filename)
line = linecache.getline(filename, lineno, f.f_globals)
print('EXCEPTION IN ({}, LINE {} "{}"): {}'.format(filename, lineno, line.strip(), exc_obj))
training_data.append([np.array(img_data), create_label(img)])
shuffle(training_data)
np.save('women_train_data.npy', training_data)
return training_data
def create_test_data():
testing_data = []
for img in tqdm(os.listdir(TEST_DIR)):
path = os.path.join(TEST_DIR, img)
img_num = img.split('.')[1]
img_data = cv2.imread(path, 0)
try:
img_data = cv2.resize(img_data, (IMG_SIZE, IMG_SIZE))
except:
exc_type, exc_obj, tb = sys.exc_info()
f = tb.tb_frame
lineno = tb.tb_lineno
filename = img
linecache.checkcache(filename)
line = linecache.getline(filename, lineno, f.f_globals)
print('EXCEPTION IN ({}, LINE {} "{}"): {}'.format(filename, lineno,
line.strip(), exc_obj))
testing_data.append([np.array(img_data), create_label(img)])
np.save('women_test_data.npy', testing_data)
return testing_data
tf.reset_default_graph()
if os.path.exists('women_train_data.npy'):
train_data = np.load('women_train_data.npy')
else:
train_data = create_train_data()
if os.path.exists('women_test_data.npy'):
test_data = np.load('women_test_data.npy')
else:
test_data = create_test_data()
train = train_data
test = test_data
print (train.shape)
X_train = np.array([i[0] for i in train]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
y_train = [i[1] for i in train]
X_test = np.array([i[0] for i in test]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
y_test = [i[1] for i in test]
conv_input = input_data(shape=[None, IMG_SIZE, IMG_SIZE, 1], name='input')
conv1 = conv_2d(conv_input, 32, 5, activation='relu')
pool1 = max_pool_2d(conv1, 5)
conv2 = conv_2d(pool1, 64, 5, activation='relu')
pool2 = max_pool_2d(conv2, 5)
conv3 = conv_2d(pool2, 128, 5, activation='relu')
pool3 = max_pool_2d(conv3, 5)
conv4 = conv_2d(pool3, 64, 5, activation='relu')
pool4 = max_pool_2d(conv4, 5)
conv5 = conv_2d(pool4, 32, 5, activation='relu')
pool5 = max_pool_2d(conv5, 5)
fully_layer = fully_connected(pool5, 1024, activation='relu')
fully_layer = dropout(fully_layer, 0.8)
cnn_layers = fully_connected(fully_layer, 56, activation='softmax')
cnn_layers = regression(cnn_layers, optimizer='adam', learning_rate=LR,
loss='categorical_crossentropy', name='targets')
model = tflearn.DNN(cnn_layers, tensorboard_dir='log',
tensorboard_verbose=3)
if os.path.exists('women_model.tfl.meta'):
model.load('./model.tfl')
else:
model.fit({'input': X_train}, {'targets': y_train}, n_epoch=10,
validation_set=({'input': X_test}, {'targets': y_test}),
snapshot_step=500, show_metric=True, run_id=MODEL_NAME)
model.save('women_model.tfl')
我用Kaggel的cat vs dog数据集尝试了此代码,并且工作正常,我在不同的数据集上也遇到了这个问题,但在将其用于不同的项目但使用相同的代码并将其重新使用后,此代码突然起作用了,但是我没有完成培训,所以我不确定是否会遇到同样的问题, 我知道这可能来自包含内部列表长度不同的数组,但是所有数据都来自图像,所以我不 对输入有完全控制权,那么我该如何解决?
答案 0 :(得分:0)
最后,我在YouTube评论之一上找到了解决方案,事实证明,如果我将其中一张图片标记为错误,我应该添加一个带有零的硬编码矢量:
T
对不起,我没有经验,我觉得这个错误是令人误解的,或者与我不知道的数组长度问题有关,我希望以后能有所帮助