我想预测一个双链多类,模型的得分很好,但是我不知道如何预测标签。
我遵循本教程:https://stackoverflow.blog/2019/05/06/predicting-stack-overflow-tags-with-googles-cloud-ai/
我使用了model.predict(xtest),但结果始终相同
X1 = data['Body']
#data['Body'].head()
# absolutely positioned div containing several c...
# given datetime representing person birthday ca...
# given specific datetime value display relative...
# expose linq query asmx web service usually bus...
# store binary data mysql
train_size = int(len(data) * .8)
train_qs = data['Body'].values[:train_size]
test_qs = data['Body'].values[train_size:]
from tensorflow.keras.preprocessing import text
tokenizer = text.Tokenizer(num_words=400)
tokenizer.fit_on_texts(train_qs)
bag_of_words_train = tokenizer.texts_to_matrix(train_qs)
bag_of_words_test = tokenizer.texts_to_matrix(test_qs)
VOCAB_SIZE=400
tags_split =data['Tags']
#data['Tags'].head()
# [html, css]
# [c#, .net, datetime]
# [c#, datetime]
# [c#, linq, web-services, .net-3.5]
# [mysql, database]
# Create the encoder
from sklearn.preprocessing import MultiLabelBinarizer
tag_encoder = MultiLabelBinarizer()
tags_encoded = tag_encoder.fit_transform(tags_split)
# Split the tags into train/test
train_tags = tags_encoded[:train_size]
test_tags = tags_encoded[train_size:]
from tensorflow.keras.preprocessing import text
import keras
import tensorflow as tf
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(50, input_shape=(VOCAB_SIZE,), activation='relu'))
model.add(tf.keras.layers.Dense(25, activation='relu'))
model.add(tf.keras.layers.Dense(100, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[keras.metrics.binary_accuracy])
model.fit(bag_of_words_train, train_tags, epochs=3, batch_size=128, validation_split=0.1)
model.evaluate(bag_of_words_test, test_tags, batch_size=128)
所以我不明白输出:
yhat = model.predict_classes(x=bag_of_words_test, batch_size=128)
print(yhat)
#[10 10 10 ... 10 10 10]
#there are only 47 prediction diferent than 10 so 1848 of 10predict
#but there are only 319 of this label in ytest
tag_encoder.classes_[10]
#'c#'