使用fit_generator()时奇怪的输入形状

时间:2018-08-15 08:47:17

标签: python keras generator

我正在使用此生成器生成数据

def batch_no_aug(file_names,batch_size,img_size_target=224) : 
    idx = 0
    images = []
    masks = []
    while True : 
        for i in range(len(file_names)) : 
            img = np.array(load_img('../input/images/{}.png'.format(file_names[i],grayscale=True)))/255
            msk = np.array(load_img('../input/masks/{}.png'.format(file_names[i],grayscale=True)))/255
            msk = 0.2989*msk[:,:,0] + 0.5870*msk[:,:,1] + 0.1140*msk[:,:,2]
            img = resample(img,img_size_target = img_size_target)
            msk = resample(msk,img_size_target = img_size_target)
            msk = msk.reshape((img_size_target,img_size_target,1))

            images.append(img)
            masks.append(msk)
            idx += 1

            if idx>=batch_size : 
                npimgs = np.array(images)
                npmsks = np.array(masks)
                images = []
                masks = []
                idx=0
                print(' No aug generator image output shape : ',npimgs.shape)
                print(' No aug generator mask output shape : ',npmsks.shape)
                yield npimgs, npmsks
        print(' No aug generator image output shape : ',npimgs.shape)
        print(' No aug generator mask output shape : ',npmsks.shape)
        yield np.array(images),np.array(masks)

这是我如何使用fit_generator()

ids,holdout_idx = train_test_split(train_df)
ids.reset_index(inplace=True)
holdout_idx.reset_index(inplace=True)

for train_idx,val_idx in kf.split(ids) : 
    for lr,ep in zip(lrs,eps) : 
        model.compile(optimizer=Adam(lr=lr),loss = dice_loss,metrics=[dice_coef])
        model.fit_generator(batch_no_aug(ids['id'].iloc[train_idx].values,batch_size=batch),
            steps_per_epoch=np.ceil(len(train_idx)/batch),epochs=ep,callbacks=callbacks,
            validation_data=batch_no_aug(ids['id'].iloc[val_idx].values,batch_size=batch),
            validation_steps=np.ceil(len(val_idx)/batch))

运行一个纪元后,我总是以这个错误结束

ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (0, 1)

在生成数据之前,我没有看到任何生成器打印出的形状为(0,1)。但是错误就在那里。

0 个答案:

没有答案