尝试使用DataGenerator类训练带有一堆图像的CNN,模型通常可以正常工作。问题是训练数据集非常偏向几个类别,因此我想添加class_weights。但是,每次执行此操作时,我都会在将我标记的类转换为单数组的代码部分中出现索引错误。
如果Keras在tensorflow顶部运行,则为
。出现问题的函数是keras.utils.to_categorical()
这是分类函数:
for i, pdb_id in enumerate(list_enzymes_temp):
mat = precomputed_distance_matrix(pdb_id, self.dim)
X[i,] = mat.distance_matrix.reshape(*self.dim)
y[i] = int(self.labels[pdb_id.upper()][1]) - 1
return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
这是我用来生成权重的函数
def get_class_weights(dictionary, training_enzymes, mode):
'Gets class weights for Keras'
# Initialization
counter = [0 for i in range(6)]
# Count classes
for enzyme in training_enzymes:
counter[int(dictionary[enzyme.upper()][1])-1] += 1
majority = max(counter)
# Make dictionary
class_weights = {i: float(majority/count) for i, count in enumerate(counter)}
# Value according to mode
if mode == 'unbalanced':
for key in class_weights:
class_weights[key] = 1
elif mode == 'balanced':
pass
elif mode == 'mean_1_balanced':
for key in class_weights:
class_weights[key] = (1+class_weights[key])/2
return class_weights
和我的fit_generator函数:
model.fit_generator(generator=training_generator,
validation_data=validation_generator,
epochs=max_epochs,
max_queue_size=16,
class_weight=class_weights,
callbacks=[tensorboard])
此处没有出现IndexError消息,并且在没有添加class_weights的情况下,模型可以完美运行:
File "C:\Users\Python\DMCNN\data_generator.py", line 73, in __getitem__
X, y = self.__data_generation(list_enzymes_temp)
File "C:\Users\Python\DMCNN\data_generator.py", line 59, in __data_generation
return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
File "C:\Users\Python\Anaconda3\lib\site-packages\keras\utils\np_utils.py", line 34, in to_categorical
categorical[np.arange(n), y] = 1
IndexError: index 1065353216 is out of bounds for axis 1 with size 6
答案 0 :(得分:0)
在使用keras.utils.to_categorical时,我遇到了相同的错误。我得到的错误是“ IndexError:索引1065353216超出了尺寸2的轴1的范围”,因为我有2个类。
我认为它是从1.0转换为1.0f(32位浮点数)的原因,因为1065353216是32位浮点值1.0的无符号32位整数表示形式(请在此处:Why is 1.0f in C code represented as 1065353216 in the generated assembly?)。在我的情况下,并非所有批次都具有相同的长度,最后以X和y填充一些空白,这导致了问题。您可以事先检查W(甚至X和Y)中是否有未填写的元素。您还可以看到keras.utils.to_categorical具有默认值dtype ='float32'。您可以尝试指定dtype,例如“在您的情况下,返回X,keras.utils.to_categorical(y,num_classes = self.n_classes,dtype ='uint8')”,以查看是否可行。