Multiclass confusion matrix with Keras

Time: 2020-08-21 21:56:17

Tags: python keras scikit-learn neural-network multilabel-classification

I'm building a neural network meant to classify 10 different chemical compounds. The dataset looks like this:

array([[400.  ,  23.  ,  52.38, ...,   1.  ,   0.  ,   0.  ],
   [400.  ,  21.63,  61.61, ...,   0.  ,   0.  ,   0.  ],
   [400.  ,  21.49,  61.95, ...,   0.  ,   0.  ,   0.  ],
   ...,
   [400.  ,  21.69,  41.98, ...,   0.  ,   0.  ,   0.  ],
   [400.  ,  22.48,  65.2 , ...,   0.  ,   0.  ,   0.  ],
   [400.  ,  22.02,  58.91, ...,   0.  ,   0.  ,   1.  ]])

The last 10 numbers are a one-hot encoding of the compound I want to identify. This is the code I'm using:

import numpy
from joblib import dump
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# `dataset`, `epochs` and `batch_size` are defined earlier in the script (not shown)
dataset = numpy.asarray(dataset[1:, 0:], dtype=float)  # drop the first row
x = dataset[:, 0:30]   # 30 input features
y = dataset[:, 30:40]  # one-hot encoding of the 10 compounds

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.20, random_state=1)  # it has always been 42

# Fit the scaler on the training split only, then apply it to both splits
standard = preprocessing.StandardScaler().fit(x_train)
x_train = standard.transform(x_train)
x_test = standard.transform(x_test)
dump(standard, 'std_modelo_400.bin', compress=True)

model = Sequential()
# Only the first layer needs the input size; later Dense layers infer it
model.add(keras.Input(shape=(x_train.shape[1],)))
model.add(Dense(50, activation='relu', kernel_regularizer=keras.regularizers.l1(0.01)))
model.add(Dense(30, activation='relu', kernel_regularizer=keras.regularizers.l1(0.01)))
model.add(Dense(15, activation='relu', kernel_regularizer=keras.regularizers.l1(0.01)))
model.add(Dense(10, activation='softmax', kernel_initializer='normal',
                bias_initializer=keras.initializers.Constant(value=0)))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, validation_data=(x_test, y_test),
                    verbose=2, epochs=epochs, batch_size=batch_size)  # callbacks=[monitor]
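The question does not show how pred is computed. multilabel_confusion_matrix expects pred in the same 0/1 one-hot format as y_test; passing the raw softmax probabilities from model.predict makes scikit-learn reject the mix of target types. A minimal sketch, assuming each row of pred should contain exactly one 1 (the predicted compound); probs_to_one_hot is a hypothetical helper and the probabilities below are made-up stand-ins for model.predict(x_test):

```python
import numpy as np

def probs_to_one_hot(probs):
    """Hypothetical helper: turn softmax outputs of shape (n, k) into a 0/1 matrix."""
    probs = np.asarray(probs)
    winners = np.argmax(probs, axis=1)  # index of the most likely class per sample
    # Row i of the identity matrix is the one-hot vector for class i
    return np.eye(probs.shape[1], dtype=int)[winners]

# Made-up stand-in for model.predict(x_test): 3 samples, 4 classes
probs = np.array([[0.10, 0.70, 0.15, 0.05],
                  [0.60, 0.20, 0.10, 0.10],
                  [0.05, 0.05, 0.10, 0.80]])
pred = probs_to_one_hot(probs)
print(pred)
```

With the model above this would be pred = probs_to_one_hot(model.predict(x_test)). Taking the argmax rather than rounding at 0.5 guarantees exactly one predicted compound per sample, which matches a softmax output layer.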

I tried to get the confusion matrix with multilabel_confusion_matrix(y_test,pred) and got output of this form:

array([[[929681,    158],
    [   308, 102180]],

   [[930346,    407],
    [  6677,  94897]],

   [[930740,     38],
    [   477, 101072]],

   [[929287,   1522],
    [    69, 101449]],

   [[929703,   8843],
    [ 12217,  81564]],

   [[902624,    474],
    [  1565, 127664]],

   [[931152,   2236],
    [ 12140,  86799]],

   [[929085,     10],
    [     0, 103232]],

   [[911158,  22378],
    [  5362,  93429]],

   [[930412,    689],
    [   617, 100609]]], dtype=int64)
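Each 2×2 block in that output corresponds to one column of y_test (one compound) treated as a one-vs-rest binary problem; per the scikit-learn documentation, block k is laid out as [[TN, FP], [FN, TP]]. Read that way, the first block above has 102180 true positives, 308 false negatives and 158 false positives for the first compound. A small sketch on toy data (not the question's) that pulls these counts out and turns them into per-class recall and precision:

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Toy one-hot data: 4 samples, 3 classes
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1]])
y_pred = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]])

mcm = multilabel_confusion_matrix(y_true, y_pred)  # shape (n_classes, 2, 2)
tn = mcm[:, 0, 0]  # true negatives per class
fp = mcm[:, 0, 1]  # false positives per class
fn = mcm[:, 1, 0]  # false negatives per class
tp = mcm[:, 1, 1]  # true positives per class

recall = tp / (tp + fn)
precision = tp / (tp + fp)
for k in range(mcm.shape[0]):
    print(f"class {k}: TP={tp[k]} FN={fn[k]} FP={fp[k]} TN={tn[k]} "
          f"recall={recall[k]:.2f} precision={precision[k]:.2f}")
```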

Calling multilabel_confusion_matrix(y_test,pred,labels=["Comp1","Comp2","Comp3", "Comp4", "Comp5", "Comp6", "Comp7", "Comp8", "Comp9", "Comp10",]) raises an error:

elementwise comparison failed; returning scalar instead, but in the future will perform     elementwise comparison
mask &= (ar1 != a)
Traceback (most recent call last):

File "<ipython-input-18-00af06ffcbef>", line 1, in <module>
multilabel_confusion_matrix(y_test,pred,labels=["Comp1","Comp2","Comp3", "Comp4", "Comp5", "Comp6", "Comp7", "Comp8", "Comp9", "Comp10",])

File "C:\Users\fmarin\Anaconda3\lib\site-packages\sklearn\metrics\_classification.py", line 485, in multilabel_confusion_matrix
if np.max(labels) > np.max(present_labels):
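With one-hot (multilabel-indicator) targets, scikit-learn identifies the labels by their column indices 0 to 9, so the labels argument of multilabel_confusion_matrix should contain those integers rather than display names. The warning and the failing line in the traceback appear to come from the string names being compared with those integer indices. A minimal sketch, assuming y_test and pred are 0/1 matrices with 10 columns (random stand-ins here) and keeping the names Comp1 to Comp10 only for display:

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

NAMES = [f"Comp{i}" for i in range(1, 11)]  # display names only

rng = np.random.default_rng(0)
# Random stand-ins for the real one-hot y_test and pred (200 samples, 10 classes)
y_test = np.eye(10, dtype=int)[rng.integers(0, 10, size=200)]
pred = np.eye(10, dtype=int)[rng.integers(0, 10, size=200)]

# labels are column indices 0..9, not the compound names
mcm = multilabel_confusion_matrix(y_test, pred, labels=list(range(10)))

# Pair each 2x2 block with its compound name for reporting
for name, block in zip(NAMES, mcm):
    tn, fp, fn, tp = block.ravel()  # block is [[TN, FP], [FN, TP]]
    print(f"{name}: TP={tp} FP={fp} FN={fn} TN={tn}")
```

Passing a subset such as labels=[0, 4] returns only the blocks for those two columns, in that order.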

I don't know how to fix this. I would also like a graphical version of the confusion matrix; I'm using the scikit-learn toolbox.
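Since every sample belongs to exactly one of the 10 compounds, this is a multiclass rather than a multilabel problem, and a single 10×10 matrix (true compound vs predicted compound) is usually easier to read than ten 2×2 blocks. A sketch, assuming y_test and pred are one-hot matrices as in the question; multiclass_cm and plot_cm are hypothetical helpers built on scikit-learn's confusion_matrix and ConfusionMatrixDisplay, and plotting requires matplotlib:

```python
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

NAMES = [f"Comp{i}" for i in range(1, 11)]

def multiclass_cm(y_onehot, pred_onehot):
    """Hypothetical helper: 10x10 matrix, rows = true compound, columns = predicted."""
    y_idx = np.argmax(y_onehot, axis=1)     # one-hot -> class index
    p_idx = np.argmax(pred_onehot, axis=1)
    # Fixing labels keeps the matrix 10x10 even if a compound is absent
    return confusion_matrix(y_idx, p_idx, labels=list(range(len(NAMES))))

def plot_cm(cm, names=NAMES):
    """Hypothetical helper: draw the matrix with ConfusionMatrixDisplay (needs matplotlib)."""
    import matplotlib.pyplot as plt
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=names)
    disp.plot(cmap="Blues", xticks_rotation=45, values_format="d")
    plt.tight_layout()
    plt.show()

# Toy check: 3 samples (classes 0, 3, 3), all predicted correctly
y_toy = np.eye(10, dtype=int)[[0, 3, 3]]
cm = multiclass_cm(y_toy, y_toy)
print(cm[3, 3], cm.sum())
```

For the real data, plot_cm(multiclass_cm(y_test, pred)) would draw the matrix; rows are true compounds and columns are predicted ones, following scikit-learn's convention. In scikit-learn 1.0 and later, ConfusionMatrixDisplay.from_predictions(y_idx, p_idx, display_labels=NAMES) computes and plots in one call.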

Thanks!
