Random crop of images from a data generator

Time: 2018-06-21 11:10:17

Tags: python tensorflow keras computer-vision crop

I found this code for randomly cropping images with the TensorFlow backend.

import numpy as np

def random_crop(img, random_crop_size):
    # Note: image_data_format is 'channels_last', i.e. img has shape (height, width, channels)
    assert img.shape[2] == 3
    height, width = img.shape[0], img.shape[1]
    dy, dx = random_crop_size
    # Random top-left corner chosen so that the crop stays inside the image
    x = np.random.randint(0, width - dx + 1)
    y = np.random.randint(0, height - dy + 1)
    return img[y:(y+dy), x:(x+dx), :]

def crop_generator(batches, crop_length):
    # Pulls (batch_x, batch_y) pairs from `batches` with next() and crops every image
    while True:
        batch_x, batch_y = next(batches)
        batch_crops = np.zeros((batch_x.shape[0], crop_length, crop_length, 3))
        for i in range(batch_x.shape[0]):
            batch_crops[i] = random_crop(batch_x[i], (crop_length, crop_length))
        yield (batch_crops, batch_y)

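In my case, I want to crop images of shape (256,256,3) down to (224,224,3). In imagenes_trigo, each item is a string containing an image path. I generated the data with the class PlantasTrigoDataset. How can I use the cropping code above to crop every image in data_train and data_val? Any help would be great. Thanks.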
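One way to crop every image of data_train and data_val is to do the crop inside `__getitem__`, so each batch is cropped at the moment it is requested by index instead of through `next()`. This is a minimal sketch, assuming each `batch_x` is a NumPy array of shape (batch, 256, 256, 3); `RandomCropBatches` is a hypothetical name. To hand it to `fit_generator` it would need to subclass `keras.utils.Sequence`; that base class is left out here only so the sketch runs with NumPy alone.

```python
import numpy as np


class RandomCropBatches:
    """Wrap an indexable source of (batch_x, batch_y) pairs and randomly crop each image.

    Hypothetical helper: for use with Keras it would subclass keras.utils.Sequence.
    """

    def __init__(self, base, crop_length, seed=None):
        self.base = base                  # e.g. a PlantasTrigoDataset instance
        self.crop_length = crop_length    # e.g. 224
        self.rng = np.random.RandomState(seed)

    def __len__(self):
        # Same number of batches as the wrapped dataset
        return len(self.base)

    def __getitem__(self, idx):
        batch_x, batch_y = self.base[idx]
        n, height, width = batch_x.shape[:3]
        c = self.crop_length
        # Keep any trailing channel axis and the original dtype
        crops = np.empty((n, c, c) + batch_x.shape[3:], dtype=batch_x.dtype)
        for i in range(n):
            # Independent random top-left corner per image; randint's upper bound
            # is exclusive, so "+ 1" allows the last valid offset
            y = self.rng.randint(0, height - c + 1)
            x = self.rng.randint(0, width - c + 1)
            crops[i] = batch_x[i, y:y + c, x:x + c]
        return crops, batch_y
```

With the `Sequence` base class added, it might be used as `train_crops = RandomCropBatches(data_train, 224)` and `valid_crops = RandomCropBatches(data_val, 224)`, passing `steps_per_epoch=len(train_crops)`. Cropping at index time leaves the batching logic of the original dataset class untouched.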

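import numpy as np
from keras.utils import Sequence
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from skimage.io import imread  # assumption: the original does not say where imread comes from

x_train, x_test, y_train, y_test = train_test_split(imagenes_trigo, labels, test_size=0.25)

class PlantasTrigoDataset(Sequence):
    def __init__(self, imagenes_trigo, labels):
        # Load every image into memory up front (tqdm shows progress)
        self.x = [imread(i) for i in tqdm(imagenes_trigo)]
        self.y = labels
        self.batch_size = 8

    def __len__(self):
        # Number of batches per epoch, including a final partial batch
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Return batch number idx as (images, labels)
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.array(batch_x), np.array(batch_y)

    def get_steps(self):
        return len(self.x) // self.batch_size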

data_train = PlantasTrigoDataset(x_train,y_train)
data_val = PlantasTrigoDataset(x_test,y_test)

I have already tried this:

crop_length = 224
train_crops = crop_generator(data_train, crop_length)
valid_crops = crop_generator(data_val, crop_length)

But when I fit the model like this:

model.fit_generator(
    train_crops,
    steps_per_epoch=2245,
    epochs=500,
    validation_data=valid_crops,
    validation_steps=748,
    callbacks=[csv_logger, checkpoint])

I get this error:

File "/home/jokin/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 658, in _data_generator_task
generator_output = next(self._generator)

File "/home/jokin/PycharmProjects/TFG/PLANTS DISEASE/plants.py", line 106, in crop_generator
batch_x, batch_y = next(batches)
TypeError: 'PlantasTrigoDataset' object is not an iterator
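The traceback points at the cause: crop_generator calls `next(batches)`, which only works on an iterator such as a generator, while a `keras.utils.Sequence` like PlantasTrigoDataset is accessed by index (`data_train[i]`) and `len()`. A minimal sketch of an adapter, assuming crop_generator should simply cycle over the batches forever; `sequence_batches` is a hypothetical name:

```python
def sequence_batches(seq):
    """Yield seq[0], seq[1], ..., seq[len(seq) - 1], then start over, forever.

    Turns an indexable batch source (e.g. a keras.utils.Sequence) into an
    iterator, so that next() works on it.
    """
    while True:
        for idx in range(len(seq)):
            yield seq[idx]
```

It would be plugged in as `train_crops = crop_generator(sequence_batches(data_train), crop_length)` and likewise for `data_val`. Because the result is a plain generator, Keras no longer sees a `Sequence`, so `steps_per_epoch` and `validation_steps` must stay explicit (for example `len(data_train)` and `len(data_val)`), and the ordering guarantees Keras documents for `Sequence` with multiprocessing no longer apply.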

0 answers