CNN 模型的混淆矩阵无法显示

时间:2019-11-19 02:02:55

标签: python

混淆矩阵在 PyCharm 中不显示。我需要画出混淆矩阵，但无法让它显示出来。
运行结束前输出的最后一行是下面的训练日志，之后就显示 “Process finished with exit code 0”（进程以退出代码 0 结束）：227/227 [==============================] - 8s 36ms/step - loss: 0.6156 - accuracy: 0.7225 - val_loss: 0.6402 - val_accuracy: 0.6154（我在这里多写几句，只是因为网站提示需要添加更多文字，请忽略，谢谢。）

import tensorflow as tf
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2

DATADIR = "C:/Users/Acer/imagerec/MRI"

CATEGORIES = ["yes", "no"]

# Sanity-check the dataset layout (DATADIR/<category>/<image files>) by
# loading the first file listed in the first category directory. Both loops
# exit after a single iteration on purpose: only one sample is previewed.
img_array = None  # stays None if the directory is empty or unreadable
for category in CATEGORIES:
    path = os.path.join(DATADIR, category)
    for img in os.listdir(path):
        # cv2.imread does not raise on failure; it returns None.
        img_array = cv2.imread(os.path.join(path, img), cv2.IMREAD_GRAYSCALE)
        break
    break

if img_array is None:
    # Fail here with a clear message instead of a confusing NameError /
    # TypeError from plt.imshow or img_array.shape further down.
    raise RuntimeError(
        "Could not load a preview image from "
        + os.path.join(DATADIR, CATEGORIES[0])
    )

plt.imshow(img_array, cmap='gray')
plt.show()

print(img_array)
print(img_array.shape)

IMG_SIZE = 50  # every image is resized to IMG_SIZE x IMG_SIZE pixels

new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
plt.imshow(new_array, cmap='gray')
plt.show()

training_data = []  # filled by create_training_data() with [image, label] pairs


def create_training_data():
    """Load, resize and label every image under DATADIR.

    For each category in CATEGORIES, every file in DATADIR/<category> is
    read as grayscale, resized to IMG_SIZE x IMG_SIZE, and appended to the
    module-level ``training_data`` list as ``[image, class_num]``.
    ``class_num`` is the category's index in CATEGORIES.

    Files OpenCV cannot decode are skipped and listed on stdout, instead of
    being hidden by a blanket ``except ...: pass``.
    """
    skipped = []
    for class_num, category in enumerate(CATEGORIES):
        path = os.path.join(DATADIR, category)
        for img in os.listdir(path):
            full_path = os.path.join(path, img)
            img_array = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
            if img_array is None:
                # imread signals unreadable / non-image files by returning
                # None; cv2.resize would then raise on the empty input.
                skipped.append(full_path)
                continue
            new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))
            training_data.append([new_array, class_num])
    if skipped:
        print("Skipped %d unreadable file(s):" % len(skipped))
        for name in skipped:
            print("  " + name)


create_training_data()

print(len(training_data))

import random

# Shuffle so the samples are not grouped by class (they were appended one
# category directory at a time). Keras' validation_split takes the *last*
# fraction of the samples, so an unshuffled list would validate on one class.
random.shuffle(training_data)
for sample in training_data[:10]:
    print(sample[1])

# Split the [image, label] pairs into a feature tensor of shape
# (N, IMG_SIZE, IMG_SIZE, 1) -- one grayscale channel -- and a label array.
# y is an ndarray (not a list) so Keras and sklearn receive a consistent type.
X = np.array([features for features, _ in training_data]).reshape(-1, IMG_SIZE, IMG_SIZE, 1)
y = np.array([label for _, label in training_data])

import pickle

# Keep labels as an ndarray for Keras / sklearn (no-op if already one).
y = np.asarray(y)

# Cache the freshly built arrays so a separate training script can reuse
# them without re-reading every image. Write here rather than read:
# reading X.pickle / y.pickle at this point would discard the arrays just
# built above in favour of a possibly stale cache from an earlier run (or
# raise FileNotFoundError on a first run). `with` guarantees the files are
# closed. Only ever pickle.load these files from a trusted source.
with open("X.pickle", "wb") as pickle_out:
    pickle.dump(X, pickle_out)

with open("y.pickle", "wb") as pickle_out:
    pickle.dump(y, pickle_out)

# Scale pixel values (0-255 for 8-bit grayscale images) into [0, 1].
X = X / 255.0

# Two conv -> ReLU -> max-pool stages, then a small dense head that ends in
# a single sigmoid unit for the binary (yes / no) decision.
model = Sequential([
    Conv2D(256, (3, 3), input_shape=X.shape[1:]),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(256, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),

    Dense(64),
    Activation('relu'),

    Dense(1),
    Activation('sigmoid'),
])

model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.fit(X, y, batch_size=5, epochs=1, validation_split=0.1)

from sklearn.metrics import confusion_matrix
import seaborn as sns

# NOTE(review): this evaluates on X, which includes the samples the model
# was trained on, so the matrix will look optimistic; a held-out split
# would give a fairer picture.
pred = model.predict(X)
# The single sigmoid output has shape (N, 1). Round to 0/1, then flatten to
# a 1-D integer vector so it matches the integer labels in y.
pred = np.round(pred).astype(int).ravel()

# Rows are true labels, columns are predictions. labels=[0, 1] keeps the
# matrix 2x2 even if one class is never predicted.
conf = confusion_matrix(y, pred, labels=[0, 1])

# Save first: in matplotlib's default non-interactive mode plt.show() blocks
# until the figure window is closed.
model.save('64x2-CNN.model')

# fmt='d' prints integer counts; the default '.2g' shows e.g. 227 as 2.3e+02.
ax = sns.heatmap(conf, annot=True, fmt='d',
                 xticklabels=CATEGORIES, yticklabels=CATEGORIES)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
# sns.heatmap only draws onto the current figure. Outside a notebook no
# window appears until plt.show() is called, which is why the confusion
# matrix never showed up when run from PyCharm.
plt.show()

1 个答案:

答案 0 :(得分:0)

您的预测值需要与目标值相匹配。把两者都打印出来，看看是否一个是浮点数、另一个是整数。如果是这样，请在

pred = model.predict(X)

之后执行以下操作：

pred = np.eye(2)[np.argmax(pred, axis=1)]

或者使用任何其他能把预测值转换为 0 和 1 的代码。

如果您的目标 `y` 是字符串列表：

（原答案此处的代码片段已缺失。）
相关问题