我发现此代码用于对张量流后端进行图像的随机裁剪。
def random_crop(img, random_crop_size)
# Note: image_data_format is 'channel_last'
assert img.shape[2] == 3
height, width = img.shape[0], img.shape[2]
dy, dx = random_crop_size
x = np.random.randint(0, width - dx + 1)
y = np.random.randint(0, height - dy + 1)
return img[y:(y+dy), x:(x+dx), :]
def crop_generator(batches, crop_length):
while True:
batch_x, batch_y = next(batches)
batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
for i in range(batch_x.shape[0]):
batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
yield (batch_crops, batch_y)
就我而言,我想将形状为(256,256,3)
的图像裁剪为(224,224,3)
。在imagenes_trigo
中,每个项目都是一个包含图像路径的字符串。我已经使用class PantasTrigoDataset
生成了数据。如何使用上面的图像代码裁剪data_train
和data_valid
的每个图像?如果有人帮助我,那就太好了。谢谢。
x_train,x_test,y_train,y_test = train_test_split(imagenes_trigo,labels,test_size=0.25)
class PlantasTrigoDataset(Sequence):
def __init__(self,imagenes_trigo,labels):
self.x=[imread(i) for i in tqdm(imagenes_trigo)]
self.y=labels
self.batch_size=8
def __len__(self):
return int(np.ceil(len(self.x)/float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([file_name for file_name in batch_x]), np.array(batch_y)
def get_steps(self):
return len(self.x)//self.batch_size
data_train = PlantasTrigoDataset(x_train,y_train)
data_val = PlantasTrigoDataset(x_test,y_test)
我已经尝试过了:
train_crops = crop_generator(data_train, crop_length)
valid_crops = crop_generator(data_val, crop_length)
但是当我拟合这样的模型时:
model.fit_generator(
train_crops,
steps_per_epoch= 2245,
epochs=500,
validation_data=valid_crops,
validation_steps=748,
callbacks=[csv_logger,checkpoint])
我收到此错误:
File "/home/jokin/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 658, in _data_generator_task
generator_output = next(self._generator)
File "/home/jokin/PycharmProjects/TFG/PLANTS DISEASE/plants.py", line 106, in crop_generator
batch_x, batch_y = next(batches)
TypeError: 'PlantasTrigoDataset' object is not an iterator