Tensorflow Estimator预测深度宽模型

时间:2017-12-05 07:02:56

标签: python tensorflow

有人可以举一个关于如何在tf.estimator上致电input_fn的明确示例吗?

我不清楚两件事情

  1. 当输入文件没有标签列时,如何调整 //purchase order update form public function update($id){ $suppliers = Supplier::all(); $categories = Category::all(); $purchaseorder_id = $id; $purchaseorder= Purchase_Order::where('id', '=', $purchaseorder_id)->get(); $purchaseorder_details_id = Purchase_Order_Details::pluck('purchase_order_id'); $purchaseorder_details = Purchase_Order_Details::where('purchase_order_id', '=', $id)->get(); //$purchaseorder_details = Purchase_Order_Details::with('purchaseorder')->get(); //dd($purchaseorder_details); return view('user/purchaseorder.edit')->with(['purchaseorder' => $purchaseorder, 'purchaseorder_details' => $purchaseorder_details, 'suppliers' => $suppliers, 'categories' => $categories]); }
  2. 如何导出预测结果
  3. 我正在研究deep_wide模型。

1 个答案:

答案 0 :(得分:0)

我会回答自己并希望这可以帮助有类似问题的人

对于问题1,是的,我们需要构建一个新的输入函数,并且很可能它将具有不同数量的列作为输入csv,因为我们正在删除标签列

def parse_csv(value):
    print('Parsing', data_file)
    columns = tf.decode_csv(value, record_defaults=_PREDICT_COLUMNS_DEFAULTS)
    features = dict(zip(_PREDICT_COLUMNS, columns))

    return features

def predict_input_fn(data_file):
    assert tf.gfile.Exists(data_file), ('%s not found. Please make sure the path is correct.' % data_file)

    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(parse_csv, num_parallel_calls=5)
    dataset = dataset.batch(1) # this is very important to keep the rank right
    iterator = dataset.make_one_shot_iterator()
    features = iterator.get_next()
    return features

然后对于问题2,您使用新的input_fn来预测结果

def predict(model):
    start_from_id = 892
    test_csv = []

    results = model.predict(
        input_fn=lambda: predict_input_fn(data_file='test.csv')
    )

    # for result in results:
    #     print 'result: {}'.format(result)