将class_weight添加到.fit_generator()会中断to_categorical()

时间:2019-09-02 19:12:25

标签: python tensorflow keras

尝试使用DataGenerator类训练带有一堆图像的CNN,模型通常可以正常工作。问题是训练数据集非常偏向几个类别,因此我想添加class_weights。但是,每次执行此操作时,我都会在将我标记的类转换为单数组的代码部分中出现索引错误。

如果Keras在tensorflow顶部运行,则为

。出现问题的函数是keras.utils.to_categorical()

这是分类函数:

for i, pdb_id in enumerate(list_enzymes_temp):
    mat = precomputed_distance_matrix(pdb_id, self.dim)

    X[i,] = mat.distance_matrix.reshape(*self.dim)

    y[i] = int(self.labels[pdb_id.upper()][1]) - 1

    return X, keras.utils.to_categorical(y, num_classes=self.n_classes)

这是我用来生成权重的函数

def get_class_weights(dictionary, training_enzymes, mode):
    'Gets class weights for Keras'
    # Initialization
    counter = [0 for i in range(6)]

    # Count classes
    for enzyme in training_enzymes:
        counter[int(dictionary[enzyme.upper()][1])-1] += 1
    majority = max(counter)

    # Make dictionary
    class_weights = {i: float(majority/count) for i, count in enumerate(counter)}

    # Value according to mode
    if mode == 'unbalanced':
        for key in class_weights:
            class_weights[key] = 1
    elif mode == 'balanced':
        pass
    elif mode == 'mean_1_balanced':
        for key in class_weights:
            class_weights[key] = (1+class_weights[key])/2

    return class_weights

和我的fit_generator函数:

model.fit_generator(generator=training_generator,
                validation_data=validation_generator,
                epochs=max_epochs,
                max_queue_size=16,
                class_weight=class_weights,
                callbacks=[tensorboard])

此处没有出现IndexError消息,并且在没有添加class_weights的情况下,模型可以完美运行:

File "C:\Users\Python\DMCNN\data_generator.py", line 73, in __getitem__
X, y = self.__data_generation(list_enzymes_temp)
File "C:\Users\Python\DMCNN\data_generator.py", line 59, in __data_generation
return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
File "C:\Users\Python\Anaconda3\lib\site-packages\keras\utils\np_utils.py", line 34, in to_categorical
categorical[np.arange(n), y] = 1
IndexError: index 1065353216 is out of bounds for axis 1 with size 6

1 个答案:

答案 0 :(得分:0)

在使用keras.utils.to_categorical时,我遇到了相同的错误。我得到的错误是“ IndexError:索引1065353216超出了尺寸2的轴1的范围”,因为我有2个类。

我认为它是从1.0转换为1.0f(32位浮点数)的原因,因为1065353216是32位浮点值1.0的无符号32位整数表示形式(请在此处:Why is 1.0f in C code represented as 1065353216 in the generated assembly?)。在我的情况下,并非所有批次都具有相同的长度,最后以X和y填充一些空白,这导致了问题。您可以事先检查W(甚至X和Y)中是否有未填写的元素。您还可以看到keras.utils.to_categorical具有默认值dtype ='float32'。您可以尝试指定dtype,例如“在您的情况下,返回X,keras.utils.to_categorical(y,num_classes = self.n_classes,dtype ='uint8')”,以查看是否可行。