I've been playing with the TensorFlow Wide and Deep tutorial using the census income dataset.
The linear/wide tutorial states:
We will train a logistic regression model, and given an individual's information our model will output a number between 0 and 1
Currently, I can't figure out how to predict the output for a single input (copied from the unit test):
TEST_INPUT_VALUES = {
    'age': 18,
    'education_num': 12,
    'capital_gain': 34,
    'capital_loss': 56,
    'hours_per_week': 78,
    'education': 'Bachelors',
    'marital_status': 'Married-civ-spouse',
    'relationship': 'Husband',
    'workclass': 'Self-emp-not-inc',
    'occupation': 'abc',
}
How do we predict and output whether this person is likely to earn <50K (0) or >=50K (1)?
Answer 0 (score: 2)
The function to use is predict, but I couldn't figure out how to feed it a single example directly (I tried numpy_input_fn and a dict of tensors). Instead, by writing the data to a temporary CSV file and reading it back with the input function from wide_deep.py, you can use the predict function:
import tensorflow as tf
import wide_deep  # the tutorial script; FLAGS come from its argument parser

TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,'
              'Husband,zyx,wvu,34,56,78,tsr,<=50K')

# Create a temporary CSV file holding the single test row
input_csv = '/tmp/census_model/test.csv'
with tf.gfile.Open(input_csv, 'w') as temp_csv:
    temp_csv.write(TEST_INPUT)

# Restore the model trained by wide_deep.py with the same model_dir and model_type
model = wide_deep.build_estimator(FLAGS.model_dir, FLAGS.model_type)
pred_iter = model.predict(input_fn=lambda: wide_deep.input_fn(input_csv, 1, False, 1))
for pred in pred_iter:
    # print(pred)  # the full prediction dict
    print(pred['classes'])
There are also other keys in pred besides classes, such as probabilities and logits.
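For instance, replacing the loop above with something like the following prints the full picture (a sketch; the key names are the ones produced by the canned tf.estimator classifiers):

for pred in pred_iter:
    print(pred['classes'])        # predicted class, e.g. [b'0'] or [b'1']
    print(pred['probabilities'])  # [P(<=50K), P(>50K)]
    print(pred['logits'])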
Answer 1 (score: 1)
Hookay, I can answer this question now. If you want to evaluate accuracy on the test set, you can follow the accepted answer, but if you want to make your own predictions, here are the steps.
First, construct a new input_fn. Note that you need to change the columns and the default column values, because the label column is not present in the data you are predicting on.
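The snippet below references _PREDICT_COLUMNS and _PREDICT_COLUMNS_DEFAULTS without defining them; here is a plausible sketch, assuming the standard census CSV schema from the tutorial with the trailing income_bracket label dropped:

# Hypothetical constants: the census columns minus the income_bracket label,
# mirroring _CSV_COLUMNS / _CSV_COLUMN_DEFAULTS in the official wide_deep.py
_PREDICT_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
]
_PREDICT_COLUMNS_DEFAULTS = [
    [0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
    [0], [0], [0], [''],
]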
def predict_input_fn(data_file):
    assert tf.gfile.Exists(data_file), (
        '%s not found. Please make sure the path is correct.' % data_file)

    # Nested so that data_file is in scope, as in the official wide_deep.py
    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.decode_csv(value, record_defaults=_PREDICT_COLUMNS_DEFAULTS)
        features = dict(zip(_PREDICT_COLUMNS, columns))
        return features

    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv, num_parallel_calls=5)
    dataset = dataset.batch(1)  # => This is very important to get the rank correct
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    return features
Then you can call it simply via:

results = model.predict(
    input_fn=lambda: predict_input_fn(data_file='test.csv')
)
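To turn the resulting generator into a readable verdict, a minimal sketch (assuming the standard tf.estimator classifier output keys; the bracket labels mirror the tutorial's two classes):

# Hypothetical post-processing of the predict() generator
LABELS = ['<=50K', '>50K']  # class 0 and class 1 in the tutorial

for result in results:
    class_id = result['class_ids'][0]                # 0 or 1
    probability = result['probabilities'][class_id]  # confidence of that class
    print('Predicted %s (probability %.3f)' % (LABELS[class_id], probability))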