我一直按照这篇指南(guide)编写自己的自定义线程安全生成器;其中的图像生成器参考了这个实现(this)。
以下是图像生成器的代码:
class ImageSequence(Sequence):
    """Keras-style ``Sequence`` yielding ``(images, labels)`` batches listed in a CSV.

    The CSV must contain an ``"Image Index"`` column holding image file names
    (joined onto ``source_image_dir``) and one column per entry of
    ``class_names`` holding that class's label. Rows are shuffled once at
    construction and, optionally, re-shuffled with a new seed after each epoch.
    """

    # Per-channel ImageNet statistics used by transform_batch_images. Hoisted
    # to class level so they are not rebuilt for every batch.
    _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    _IMAGENET_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, dataset_csv_file, class_names, source_image_dir, batch_size=16,
                 target_size=(224, 224), augmenter=None, verbose=0, steps=None,
                 shuffle_on_epoch_end=True, random_state=1):
        """
        :param dataset_csv_file: str, path of dataset csv file
        :param class_names: list of str, the label columns to read from the csv
        :param source_image_dir: str, directory the "Image Index" entries are joined onto
        :param batch_size: int, number of rows per batch
        :param target_size: tuple(int, int); stored but NOT applied -- images are
            not resized, so a batch only stacks into one array if its images
            already share the same size
        :param augmenter: optional object exposing ``augment_images(batch)``;
            applied to every batch before normalization when not None
        :param verbose: int (stored; not used by this class)
        :param steps: int or None; batches per epoch. Defaults to
            ceil(number_of_rows / batch_size)
        :param shuffle_on_epoch_end: bool, re-shuffle rows with the next seed
            after every epoch
        :param random_state: int, seed for the row shuffle
        """
        self.dataset_df = pd.read_csv(dataset_csv_file)
        self.source_image_dir = source_image_dir
        self.batch_size = batch_size
        self.target_size = target_size
        self.augmenter = augmenter
        self.verbose = verbose
        self.shuffle = shuffle_on_epoch_end
        self.random_state = random_state
        self.class_names = class_names
        self.prepare_dataset()
        if steps is None:
            self.steps = int(np.ceil(len(self.x_path) / float(self.batch_size)))
        else:
            self.steps = int(steps)

    def __bool__(self):
        """Always truthy, independent of ``__len__`` (which truth-testing would otherwise use)."""
        return True

    def __len__(self):
        """Number of batches per epoch."""
        return self.steps

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(normalized_images, labels)``."""
        # One slice object keeps the image paths and labels of a batch aligned.
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_x = np.asarray([self.load_image(x_path) for x_path in self.x_path[rows]])
        batch_x = self.transform_batch_images(batch_x)
        batch_y = self.y[rows]
        return (batch_x, batch_y)

    def load_image(self, image_file):
        """Load one image as an RGB float array scaled to [0, 1].

        The image is not resized (see ``target_size`` in ``__init__``).
        """
        image_path = os.path.join(self.source_image_dir, image_file)
        # Close the file handle as soon as the pixels are copied out; the
        # original left one open handle per image loaded.
        with Image.open(image_path) as image:
            image_array = np.asarray(image.convert("RGB"))
        return image_array / 255.

    def transform_batch_images(self, batch_x):
        """Optionally augment a batch, then normalize it with ImageNet mean/std.

        The mean/std arrays have 3 entries and broadcast over the LAST axis,
        so this assumes channels-last batches.
        """
        if self.augmenter is not None:
            batch_x = self.augmenter.augment_images(batch_x)
        return (batch_x - self._IMAGENET_MEAN) / self._IMAGENET_STD

    def get_y_true(self):
        """Labels of the first ``steps * batch_size`` rows, i.e. the rows one epoch covers."""
        return self.y[:self.steps * self.batch_size, :]

    def prepare_dataset(self):
        """Shuffle all rows with ``self.random_state``; set ``x_path`` and ``y``."""
        df = self.dataset_df.sample(frac=1., random_state=self.random_state)
        # ``as_matrix()`` was deprecated in pandas 0.23 and removed in 1.0;
        # ``.values`` returns the same ndarray on every pandas version.
        self.x_path = df["Image Index"].values
        self.y = df[self.class_names].values

    def on_epoch_end(self):
        """Re-shuffle with the next seed when shuffling per epoch is enabled."""
        if self.shuffle:
            self.random_state += 1
            self.prepare_dataset()
当我在 TPU 上通过 fit_generator 方法运行此代码时,收到以下错误消息:
'ImageSequence' object has no attribute 'shape'
生成器本身工作正常:如果直接调用该生成器,它能够正常产出数据批次。任何帮助都将不胜感激,先行致谢。