使用Tensor Flow中的TF contrib预测器进行预测

时间:2017-11-17 02:06:47

标签: python tensorflow tensorflow-serving

使用以下代码训练人口普查数据集的python模型: https://gist.github.com/gaganmalhotra/8c40e7650f27cf3f894bad092fbe01ab

我成功地能够在python中训练/保存和加载模型。

但是,使用预测器进行预测会产生如下错误:

ValueError: Got unexpected keys in input_dict: set(['workclass', 'gender', 'marital_status', 'race', 'native_country', 'education', 'occupation'])

当传递给模型的数据只包含这些张量时,可以在我上面分享的要点中找到该代码。另请参阅下面的代码段 -

from tensorflow.contrib import predictor
export_dir = "/Users/Documents/SampleTF_projects/tempppp/1510877466/"
predict_fn = predictor.from_saved_model(export_dir, signature_def_key=None)

K_CATEGORICAL_COLUMNS = ["gender", "native_country", "education", "occupation", "workclass", "marital_status", "race"]

def test_ip(df):
  categorical_cols = {k: tf.SparseTensor(
      indices=[[i, 0] for i in range(df[k].size)],
      values=df[k].values,
      dense_shape=[df[k].size, 1])
                      for k in K_CATEGORICAL_COLUMNS}
  return categorical_cols

# Get a sample from training dataframe
input_test = df_train[2:3]

# Passing through the input function will return the dict of corresponding tensors
dict_test = test_ip(input_test)

#Now making the prediction
predictions = predict_fn(dict_test)   #<<<<<<Error caused at this line
print(predictions['probabilities'])

我不确定只有这些功能用于训练数据,现在我们使用相同的功能进行预测。

@ash

0 个答案:

没有答案