Keras:如何获取预测类的索引(即预测类的最接近成员的索引)

时间:2019-05-23 10:20:30

标签: python tensorflow keras

让我给您一个例子,说明我正在尝试使用Iris数据集实现的目标。我不仅要预测花朵的类别,还想知道哪种花朵最相似。

在数据集中,我为从0到149开始的每朵花添加了索引。我希望从输出中获取预测的类以及该类最接近成员的索引。

这是我尝试使用的模型see picture

这是我使用的代码:

加载数据

from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.utils import to_categorical
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

iris = load_iris()

data = iris.data
target = iris.target.reshape(150,1)

indexes = np.arange(target.shape[0]).reshape(150,1)

table = np.concatenate([indexes,iris.target.reshape(150,1)], axis = 1)
table = pd.DataFrame(table, columns=['index', 'real_class'])

y_index = to_categorical( indexes )

target = to_categorical( target )

分为训练和测试数据

X_train, X_test, Y_train, Y_test, Y_train_index, Y_test_index = train_test_split( data, target, y_index, test_size=0.5, shuffle = True)

创建Keras模型

main_input = Input(shape=(data.shape[1],), name='main_input')

x = Dense(12, activation='relu')(main_input)
x = Dense(6, activation='relu')(x)

main_output = Dense(target.shape[1], activation='softmax', name='main_output')(x)
auxiliary_output = Dense(Y_train_index.shape[1], activation='softmax', name='index_of_claim')(x)

model = Model(inputs=main_input, outputs=[main_output,auxiliary_output])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])

model.fit(X_train, [Y_train,Y_train_index], validation_data=(X_test, [Y_test,Y_test_index]), epochs=100, batch_size=50)

预测类和索引

proba,proba_index = model.predict([X_test])

predictions = np.concatenate([np.argmax(proba_index, axis = 1).reshape(75,1), np.argmax(proba, axis = 1).reshape(75,1)], axis = 1)
predictions = pd.DataFrame(predictions, columns=['index', 'predicted_class'])

fact_table = pd.merge(predictions, table, on='index', how='left')

检查索引的类是否对应于真实的类

fact_table['predicted_class'].equals(fact_table['real_class'])

这总是给出 False

如何更改模型,以便始终获得预测类以匹配预测索引的类?

0 个答案:

没有答案