ImageDataGenerator-预处理X_train

时间:2020-06-04 18:09:33

标签: python tensorflow keras conv-neural-network generator

我有两个数据集。第一个包含图像数据路径,因此是我的输入X_train的路径。第二个数据集包含标签,它们是一种经过热编码的格式,并且格式特殊,其形状是3维的(图像数量,标签的长度,字符可能性),即我的数据集的(n, 8, 36)。标签是y_train数据。

标签的形状是为什么我要寻找一种方法来成批读取X_train数据并与y_train数据分开进行预处理的原因。是否有这种方法,或者您知道如何解决此问题?

非常感谢!

1 个答案:

答案 0 :(得分:1)

您可以通过创建从Sequence class继承的类来使用自定义keras生成器。

另一个[具有更多详细信息的答案[[Clarification about keras.utils.Sequence

这是一个例子

class Custom_Generator(keras.utils.Sequence) :
    def __init__(self,...,datapath, batch_size, ..) :

    def __len__(self) :
        #calculate data len, something like len(train_labels)


    def load_and_preprocess_function(self, label_names, ...):
        #do something...
        #load data for the batch using label names with whatever library

    def __getitem__(self, idx) :
        batch_y = train_labels[idx:idx+batch_size]
        batch_x = self.load_and_preprocess_function()
        return ( batch_x, batch_y )