我实施了一个生成器来提供培训流程,但是fit_generator
会抛出此错误:
检查目标时出错:预期lambda_2有4个维度, 但是有阵形(200,1)
似乎函数在某个时刻切换X和y,因为(200,1)是“y”形,而不是“X”形。
如果我使用下面的代码测试生成器,它可以正常工作:
for i in range(32):
train = next(train_generator)
print(train[0].shape)
但是fit_generator
会抛出错误。
这是我的代码:
import os
import csv
samples = []
with open('data/driving_log.csv') as csvfile:
reader = csv.reader(csvfile)
for line in reader:
samples.append(line)
from sklearn.model_selection import train_test_split
train_samples, validation_samples = train_test_split(samples, test_size=0.2)
import cv2
import numpy as np
import sklearn
def generator(samples, batch_size=32):
num_samples = len(samples)
while 1: # Loop forever so the generator never terminates
sklearn.utils.shuffle(samples)
for offset in range(0, num_samples, batch_size):
batch_samples = samples[offset:offset+batch_size]
images = []
angles = []
for batch_sample in batch_samples:
name = 'data\\'+batch_sample[0].split('\\')[-1]
center_image = cv2.imread(name)
center_angle = float(batch_sample[3])
if not center_image is None:
images.append(center_image)
angles.append(center_angle)
# trim image to only see section with road
X_train = np.array(images)
y_train = np.array(angles)
yield sklearn.utils.shuffle(X_train, y_train)
# compile and train the model using the generator function
train_generator = generator(train_samples, batch_size=int(len(train_samples)/32))
validation_generator = generator(validation_samples, batch_size=int(len(validation_samples)/32))
ch, row, col = 3, 160, 320 # Trimmed image format
from keras.models import Sequential
from keras.layers import Lambda
model = Sequential()
# Preprocess incoming data, centered around zero with small standard deviation
model.add(Lambda(lambda x: x/127.5 - 1.,
input_shape=(row, col, ch),
output_shape=(row, col, ch)))
#model.add(... finish defining the rest of your model architecture here ...)
model.compile(loss='mse', optimizer='adam')
model.fit_generator(train_generator,
steps_per_epoch=len(train_samples) / 32, validation_data=validation_generator,
validation_steps=len(validation_samples)/32, epochs=3)
我有什么想法可以解决这个问题吗?
答案 0 :(得分:0)
这是加载图片时出错,名称未正确定义。由于cv2.imread(name)
没有引发错误,因为它没有找到图像,只返回一个None对象,该方法返回一个空变量,导致网络上的错误。