使用以下代码训练人口普查数据集的python模型: https://gist.github.com/gaganmalhotra/8c40e7650f27cf3f894bad092fbe01ab
我成功地能够在python中训练/保存和加载模型。
但是,使用预测器进行预测会产生如下错误:
ValueError: Got unexpected keys in input_dict: set(['workclass', 'gender', 'marital_status', 'race', 'native_country', 'education', 'occupation'])
当传递给模型的数据只包含这些张量时,可以在我上面分享的要点中找到该代码。另请参阅下面的代码段 -
from tensorflow.contrib import predictor
export_dir = "/Users/Documents/SampleTF_projects/tempppp/1510877466/"
predict_fn = predictor.from_saved_model(export_dir, signature_def_key=None)
K_CATEGORICAL_COLUMNS = ["gender", "native_country", "education", "occupation", "workclass", "marital_status", "race"]
def test_ip(df):
categorical_cols = {k: tf.SparseTensor(
indices=[[i, 0] for i in range(df[k].size)],
values=df[k].values,
dense_shape=[df[k].size, 1])
for k in K_CATEGORICAL_COLUMNS}
return categorical_cols
# Get a sample from training dataframe
input_test = df_train[2:3]
# Passing through the input function will return the dict of corresponding tensors
dict_test = test_ip(input_test)
#Now making the prediction
predictions = predict_fn(dict_test) #<<<<<<Error caused at this line
print(predictions['probabilities'])
我不确定只有这些功能用于训练数据,现在我们使用相同的功能进行预测。
@ash