让我给您一个例子,说明我正在尝试使用Iris数据集实现的目标。我不仅要预测花朵的类别,还想知道哪种花朵最相似。
在数据集中,我为从0到149开始的每朵花添加了索引。我希望从输出中获取预测的类以及该类最接近成员的索引。
这是我尝试使用的模型see picture
这是我使用的代码:
加载数据
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import np_utils
from keras.utils import to_categorical
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
iris = load_iris()
data = iris.data
target = iris.target.reshape(150,1)
indexes = np.arange(target.shape[0]).reshape(150,1)
table = np.concatenate([indexes,iris.target.reshape(150,1)], axis = 1)
table = pd.DataFrame(table, columns=['index', 'real_class'])
y_index = to_categorical( indexes )
target = to_categorical( target )
分为训练和测试数据
X_train, X_test, Y_train, Y_test, Y_train_index, Y_test_index = train_test_split( data, target, y_index, test_size=0.5, shuffle = True)
创建Keras模型
main_input = Input(shape=(data.shape[1],), name='main_input')
x = Dense(12, activation='relu')(main_input)
x = Dense(6, activation='relu')(x)
main_output = Dense(target.shape[1], activation='softmax', name='main_output')(x)
auxiliary_output = Dense(Y_train_index.shape[1], activation='softmax', name='index_of_claim')(x)
model = Model(inputs=main_input, outputs=[main_output,auxiliary_output])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])
model.fit(X_train, [Y_train,Y_train_index], validation_data=(X_test, [Y_test,Y_test_index]), epochs=100, batch_size=50)
预测类和索引
proba,proba_index = model.predict([X_test])
predictions = np.concatenate([np.argmax(proba_index, axis = 1).reshape(75,1), np.argmax(proba, axis = 1).reshape(75,1)], axis = 1)
predictions = pd.DataFrame(predictions, columns=['index', 'predicted_class'])
fact_table = pd.merge(predictions, table, on='index', how='left')
检查索引的类是否对应于真实的类
fact_table['predicted_class'].equals(fact_table['real_class'])
这总是给出 False
如何更改模型,以便始终获得预测类以匹配预测索引的类?