使用输入fn

时间:2017-10-26 07:26:02

标签: tensorflow classification predict

我使用https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py中的教程代码,代码工作正常,直到我尝试进行预测而不是仅仅进行评估。我试图为预测创建另一个看起来像这样的函数(通过删除参数y):

def input_fn_predict(data_file, num_epochs, shuffle):
  """Input builder function."""
  df_data = pd.read_csv(
      tf.gfile.Open(data_file),
      names=CSV_COLUMNS,
      skipinitialspace=True,
      engine="python",
      skiprows=1)
  # remove NaN elements
  df_data = df_data.dropna(how="any", axis=0)
  labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
  return tf.estimator.inputs.pandas_input_fn( #removed paramter y
      x=df_data,
      batch_size=100,
      num_epochs=num_epochs,
      shuffle=shuffle,
      num_threads=5)

并称之为:

predictions = m.predict(
      input_fn=input_fn_predict(test_file_name, num_epochs=1, shuffle=True)
  )
  for i, p in enumerate(predictions):
      print(i, p)
  • 我做得对吗?
  • 为什么我得到预测81404而不是16282(测试文件中的行数)?
  • 每行包含以下内容:
  

{'概率':数组([0.78595656,0.21404342],dtype = float32),   'logits':array([ - 1.3007226],dtype = float32),'classes':array(['0'],   dtype = object),'class_ids':array([0]),'logistic':array([   0.21404341],dtype = float32)}

我如何阅读?

1 个答案:

答案 0 :(得分:11)

您需要设置shuffle=False,因为要预测新标签,您需要维护数据顺序。

下面是我运行预测的代码(我已经测试过了)。输入文件类似于测试数据(在csv中),但没有标签列。



    def predict_input_fn(data_file):
        global CSV_COLUMNS
        CSV_COLUMNS = CSV_COLUMNS[:-1]
        df_data = pd.read_csv(
            tf.gfile.Open(data_file),
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine='python',
            skiprows=1
        )

        # remove NaN elements
        df_data = df_data.dropna(how='any', axis=0)

        return tf.estimator.inputs.pandas_input_fn(
            x=df_data,
            num_epochs=1,
           shuffle=False
        )

要打电话:



    predict_file_name = 'tutorials/data/adult.predict'
    results = m.predict(
        input_fn=predict_input_fn(predict_file_name)
    )
    for result in results:
        print 'result: {}'.format(result)

一个样本的预测结果如下:



    {
        'probabilities': array([0.78595656, 0.21404342], dtype = float32),
        'logits': array([-1.3007226], dtype = float32),
        'classes': array(['0'], dtype = object),
        'class_ids': array([0]),
        'logistic': array([0.21404341], dtype = float32)
    }

每个字段的含义

  • '概率':数组([0.78595656,0.21404342],dtype = float32)
    它预测输出标签是0级(在这种情况下< = 50K) 信心0.78595656
  • 'logits':array([ - 1.3007226],dtype = float32)
    等式1 /(1 + e ^( - z))中z的值为-1.3。
  • 'classes':array(['0'],dtype = object)
    类标签为0