Model prediction for random forest classification in TensorFlow

Time: 2020-05-14 13:48:02

Tags: tensorflow machine-learning model random-forest prediction

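I trained a random forest model and got good accuracy, around 90%. After training, I want to make a prediction by feeding a single input, and this is where I am stuck. I got the reference code from link:
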
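import numpy as np
import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in TF 2.x
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.ops import resources
from sklearn.model_selection import train_test_split

# Parameters
num_steps = 500  # Total steps to train
num_classes = 2
num_features = 20
num_trees = 10
max_nodes = 1000

# x: feature matrix (n_samples, num_features), y: integer class labels; loaded earlier (not shown)
X_train, X_test, Y_train, Y_test = train_test_split(x, y, test_size=0.2, random_state=1)

# One sample reshaped into a batch of one: shape (1, num_features)
inputData = X_train[1]
single_voicedata2d = np.reshape(inputData, (1, num_features))
print(inputData)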

# Input and Target placeholders
X = tf.placeholder(tf.float32, shape=[None, num_features], name="myInput")
Y = tf.placeholder(tf.int32, shape=[None], name="myOutput")

# Random Forest Parameters
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                      num_features=num_features,
                                      num_trees=num_trees,
                                      max_nodes=max_nodes).fill()

# Build the Random Forest
forest_graph = tensor_forest.RandomForestGraphs(hparams)

# Get training graph and loss
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)

# Measure the accuracy
# infer_op: per-class probabilities for each row of X
infer_op, _, _ = forest_graph.inference_graph(X)
# Compare the predicted class (argmax over infer_op) with the label Y
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64), name="predictions")
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")

# Initialize the variables (i.e. assign their default value) and forest resources
init_vars = tf.group(tf.global_variables_initializer(), resources.initialize_resources(resources.shared_resources()))

# Start TensorFlow session
sess = tf.Session()

# Run the initializer
sess.run(init_vars)

# Training
saver = tf.train.Saver()
for i in range(1, num_steps + 1):
    _, l = sess.run([train_op, loss_op], feed_dict={X: X_train, Y: Y_train})
    if i % 50 == 0 or i == 1:
        acc = sess.run(accuracy_op, feed_dict={X: X_train, Y: Y_train})
        print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

# Save model
saver.save(sess, 'new_models/my_second_model.ckpt')

# Test Model
print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: X_test, Y: Y_test}))
# Predict a single input: this is the line that raises the error
print("Prediction:", sess.run(correct_prediction, feed_dict={X: single_voicedata2d}))
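The accuracy ops in the code above can be mirrored in plain Python to show what each one computes. This is an illustrative sketch, not TensorFlow code: the probability rows stand in for the output of infer_op and are made-up values, and argmax_rows and accuracy are hypothetical helpers. tf.argmax(infer_op, 1) picks the most probable class per row, correct_prediction compares that class with the label fed to Y, and accuracy_op averages the matches.

```python
from typing import List


def argmax_rows(probs: List[List[float]]) -> List[int]:
    """Index of the largest value in each row, like tf.argmax(probs, 1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


def accuracy(probs: List[List[float]], labels: List[int]) -> float:
    """Fraction of rows whose predicted class equals the label, like accuracy_op."""
    predicted = argmax_rows(probs)
    correct = [p == y for p, y in zip(predicted, labels)]  # like correct_prediction
    return sum(correct) / len(correct)


# Made-up class probabilities for 4 inputs and 2 classes (stand-in for infer_op)
probs = [[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]]
labels = [0, 1, 1, 1]  # made-up labels (what would be fed to Y)
print(argmax_rows(probs))       # [0, 1, 0, 1]
print(accuracy(probs, labels))  # 0.75
```

Note that argmax_rows works on the probabilities alone, while accuracy (like correct_prediction) cannot be computed without labels.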

I don't know the correct syntax for prediction. The line below, where I feed the single input for prediction, shows an error:

print("Prediction:", sess.run(correct_prediction, feed_dict={X: single_voicedata2d}))
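In graph-mode TensorFlow 1.x, sess.run only succeeds if every placeholder the fetched tensor depends on is fed. correct_prediction is built from both infer_op and the label placeholder Y, so fetching it with only X fed fails; for an unlabeled input, a tensor that depends only on the X placeholder, such as tf.argmax(infer_op, 1), would be the thing to fetch, e.g. sess.run(tf.argmax(infer_op, 1), feed_dict={X: single_voicedata2d}). The plain-Python sketch below illustrates the same idea without TensorFlow, under explicit assumptions: tree_a, tree_b and predict_proba are hypothetical stand-ins (averaging invented per-tree class probabilities plays the role of infer_op), and the sample values are made up. A single sample is wrapped as a batch of one, mirroring the np.reshape call to shape (1, num_features), and the predicted class is obtained without any label.

```python
from typing import Callable, List

# Hypothetical stand-in for a trained tree: features -> class probabilities
Tree = Callable[[List[float]], List[float]]


def tree_a(x: List[float]) -> List[float]:
    # Invented split on feature 0 (illustration only)
    return [0.8, 0.2] if x[0] > 0.5 else [0.3, 0.7]


def tree_b(x: List[float]) -> List[float]:
    # Invented split on feature 1 (illustration only)
    return [0.6, 0.4] if x[1] > 0.5 else [0.1, 0.9]


def predict_proba(trees: List[Tree], batch: List[List[float]]) -> List[List[float]]:
    """Average the per-tree class probabilities for each row (plays the role of infer_op)."""
    result = []
    for x in batch:
        per_tree = [tree(x) for tree in trees]
        n_classes = len(per_tree[0])
        result.append([sum(p[c] for p in per_tree) / len(trees) for c in range(n_classes)])
    return result


def predict(trees: List[Tree], batch: List[List[float]]) -> List[int]:
    """Most probable class per row (like tf.argmax(infer_op, 1)); no labels are used."""
    return [max(range(len(row)), key=row.__getitem__) for row in predict_proba(trees, batch)]


single_sample = [0.9, 0.2]      # one feature vector (made-up values)
batch_of_one = [single_sample]  # shape (1, num_features), like the np.reshape call
print(predict([tree_a, tree_b], batch_of_one))  # [1]
```

Here the averaged probabilities for the sample are about [0.45, 0.55], so class 1 is predicted; only the features are needed, just as only X would need to be fed.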

0 answers