Keras generators and fit_generator: how to build a generator to avoid the "function shape" error

Asked: 2019-03-29 07:26:26

Tags: python tensorflow keras

I am building a generator for Keras so that I can load the images of my dataset in batches, because the dataset is a bit too large for my RAM.

I built the generator like this:

# import the necessary packages
import tensorflow
from tensorflow import keras
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2  # needed by image_generator below (cv2.imread, cv2.resize)

#loading
path_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset-images_improved.txt"
df = pd.read_csv(path_to_txt ,sep='\t')
arr = np.array(df)
#epochs and steps:
NUM_TRAIN_IMAGES = 0
NUM_EPOCHS = 30

def image_generator(arr, bs, mode="train", aug=None):
  while True:
    images = []
    labels = []
    for row in arr:
      if len(images) < bs:
        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" + 
        row[0]),(224,224)))
        images.append(img)
        labels.append([row[2]])
        NUM_TRAIN_IMAGES += 1
      else:
        break


  if aug is not None:
    (images, labels) = next(aug.flow(np.array(images),labels, 
     batch_size=bs))

  obj = OneHotEncoder()
  values = obj.fit_transform(labels).toarray()

  yield (np.array(images), labels)

Then I call fit_generator on my Sequential model (the CNN itself works until it hits an OOM error):

#create the augmentation function:
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

#create the generator:
gen = image_generator(arr, bs = 32, mode = "train", aug = aug)

history = model.fit_generator(image_generator,
    steps_per_epoch = NUM_TRAIN_IMAGES,
    epochs = NUM_EPOCHS)

From this call, I get the following error:

# Create generator from NumPy or EagerTensor Input.
--> 377   num_samples = int(nest.flatten(data)[0].shape[0])
378   if batch_size is None:
379     raise ValueError('You must specify `batch_size`')
AttributeError: 'function' object has no attribute 'shape'

1 Answer:

Answer 0 (score: 1)

I see two main errors here.

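First, your generator function is not memory efficient, because it loads all the images up front (in the while loop) before anything is yielded. You should iterate over the image files and yield the np.array of images together with its labels inside the loop.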
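Something along these lines is what I mean by yielding inside the loop. This is only a minimal sketch, not a drop-in replacement: `batch_generator`, the `load_image` callable, and the `(path, label)` row layout are names I made up for illustration (in your code the path is `row[0]` and the label is `row[2]`), and augmentation and one-hot encoding are left out.

```python
import numpy as np


def batch_generator(rows, bs, load_image):
    """Yield (images, labels) batches forever, loading only `bs` images at a time.

    rows:       sequence of (path, label) pairs (hypothetical layout)
    load_image: callable that maps a path to an image array, for example a
                small wrapper around cv2.imread followed by cv2.resize
    """
    while True:  # the generator is expected to loop indefinitely
        for start in range(0, len(rows), bs):
            chunk = rows[start:start + bs]
            # Only the images of the current batch are held in memory.
            images = np.array([load_image(path) for path, _ in chunk])
            labels = np.array([label for _, label in chunk])
            # The yield sits inside the loop, so one batch is produced per step.
            yield images, labels
```

With a generator like this, `steps_per_epoch` would be the number of batches per epoch, for example `math.ceil(len(rows) / bs)`, rather than a counter that is updated inside the generator.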

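Second, you pass the name of the generator function (image_generator) to fit_generator, when you should pass gen, the object returned by calling the generator function.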
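The difference is easy to check in plain Python. The sketch below uses a toy generator (the string "batch" stands in for a real `(images, labels)` tuple) and does not call Keras at all. As far as I can tell from the traceback, Keras treated the function object as array-like input and tried to read its `.shape`, which a function does not have.

```python
import types


def image_generator():
    # Toy stand-in for the real generator; it yields a placeholder value.
    while True:
        yield "batch"


# Calling the function returns a generator object; the function itself is not one.
gen = image_generator()

is_function = isinstance(image_generator, types.FunctionType)
is_generator = isinstance(gen, types.GeneratorType)
first_batch = next(gen)

# With the names from the question, the call would then look like this
# (not executed here; number_of_batches is a hypothetical variable):
# history = model.fit_generator(gen,
#     steps_per_epoch=number_of_batches,
#     epochs=NUM_EPOCHS)
```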