model.fit_generator:检查目标时出错:预期lambda_2有4个维度,但得到的数组有形状(200,1)

时间:2017-06-29 18:46:54

标签: keras

我实施了一个生成器来提供培训流程,但是fit_generator会抛出此错误:

  

检查目标时出错:预期lambda_2有4个维度,   但是有阵形(200,1)

似乎函数在某个时刻切换X和y,因为(200,1)是“y”形,而不是“X”形。

如果我使用下面的代码测试生成器,它可以正常工作:

for i in range(32):    
    train = next(train_generator)
    print(train[0].shape)

但是fit_generator会抛出错误。

这是我的代码:

import os
import csv

samples = []
with open('data/driving_log.csv') as csvfile:
    reader = csv.reader(csvfile)
    for line in reader:
        samples.append(line)

from sklearn.model_selection import train_test_split
train_samples, validation_samples = train_test_split(samples, test_size=0.2)

import cv2
import numpy as np
import sklearn

def generator(samples, batch_size=32):
    num_samples = len(samples)
    while 1: # Loop forever so the generator never terminates
        sklearn.utils.shuffle(samples)
        for offset in range(0, num_samples, batch_size):
            batch_samples = samples[offset:offset+batch_size]

            images = []
            angles = []
            for batch_sample in batch_samples:
                name = 'data\\'+batch_sample[0].split('\\')[-1]
                center_image = cv2.imread(name)
                center_angle = float(batch_sample[3])

                if not center_image is None:
                    images.append(center_image)
                    angles.append(center_angle)



            # trim image to only see section with road
            X_train = np.array(images)
            y_train = np.array(angles)

            yield sklearn.utils.shuffle(X_train, y_train)

# compile and train the model using the generator function
train_generator = generator(train_samples, batch_size=int(len(train_samples)/32))
validation_generator = generator(validation_samples, batch_size=int(len(validation_samples)/32))

ch, row, col = 3, 160, 320  # Trimmed image format

from keras.models import Sequential
from keras.layers import Lambda


model = Sequential()
# Preprocess incoming data, centered around zero with small standard deviation 
model.add(Lambda(lambda x: x/127.5 - 1.,
        input_shape=(row, col, ch),
        output_shape=(row, col, ch)))
#model.add(... finish defining the rest of your model architecture here ...)

model.compile(loss='mse', optimizer='adam')

model.fit_generator(train_generator,
                    steps_per_epoch=len(train_samples) / 32, validation_data=validation_generator,
                    validation_steps=len(validation_samples)/32, epochs=3)

我有什么想法可以解决这个问题吗?

1 个答案:

答案 0 :(得分:0)

这是加载图片时出错,名称未正确定义。由于cv2.imread(name)没有引发错误,因为它没有找到图像,只返回一个None对象,该方法返回一个空变量,导致网络上的错误。