我已经创建了线性分类模型,并将其检查点存储在硬盘上。
在另一个python文件中,我希望恢复模型,并使用该模型预测CSV文件中存储的数据的值。我有以下代码:
import tensorflow as tf
import pandas as pd
# file with data that must be predicted
file = 'census_predict.csv'
x_test = pd.read_csv(file)
print(x_test)
# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "models/1/1.ckpt")
print("Model restored.")
pred_fn = tf.estimator.inputs.pandas_input_fn(x=x_test, batch_size=len(x_test), shuffle=False)
predictions = list(sess.predict(input_fn=pred_fn))
print(predictions)
# close TF session in case a new one must be created later on
tf.reset_default_graph()
sess.close()
现在我遇到以下错误:
AttributeError: 'Session' object has no attribute 'predict'
我一直在研究stackoverflow和其他网站,但似乎无法弄清楚如何正确使用它。谁能指导我正确的方向?