从tensorflow中的csv文件预测值

时间:2018-08-06 19:18:37

标签: python tensorflow

我已经创建了线性分类模型,并将其检查点存储在硬盘上。

在另一个python文件中,我希望恢复模型,并使用该模型预测CSV文件中存储的数据的值。我有以下代码:

import tensorflow as tf
import pandas as pd

# file with data that must be predicted
file = 'census_predict.csv'
x_test = pd.read_csv(file)

print(x_test)

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "models/1/1.ckpt")
    print("Model restored.")


    pred_fn = tf.estimator.inputs.pandas_input_fn(x=x_test, batch_size=len(x_test), shuffle=False)
    predictions = list(sess.predict(input_fn=pred_fn))
    print(predictions)

# close TF session in case a new one must be created later on
tf.reset_default_graph()
sess.close()

现在我遇到以下错误:

AttributeError: 'Session' object has no attribute 'predict'

我一直在研究stackoverflow和其他网站,但似乎无法弄清楚如何正确使用它。谁能指导我正确的方向?

0 个答案:

没有答案