TensorFlow DNNClassifier功能列

时间:2018-01-05 03:47:25

标签: python numpy tensorflow machine-learning

我已成功使用了来自tensorflow的mnist数据集,这是我从头开始编码的深度神经网络模型。现在我想尝试使用tensorflow中更简单的DNNClassifier函数。我面临两个问题,我似乎无法在互联网上找到解决方案

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)


estimator = tf.estimator.DNNClassifier(
    feature_columns=[28,9], #WTH is this!?!!
    hidden_units=[28, 512, 9],
    optimizer=tf.train.GradientDescentOptimizer(
        learning_rate=0.1,
    ))
x,y =mnist.train.next_batch(500)


estimator.train(input_fn=x,y,steps = 100)
  1. 有一个参数绑定到DNNClassifier,即feature_columns。它的文档对我来说没有意义。什么是feature_columns?它在深层神经网络中扮演什么角色?应该给出什么类型的变量? (list,tuple,numpy array?)

  2. 如何将训练数据输入估算器?我无法解决feature_columns问题,因此我无法让估算工作,这意味着我无法让培训工作。但我相信我目前的编码是错误的做法。

2 个答案:

答案 0 :(得分:1)

该变量创建feature_columns,它指定模型的输入。不确定为什么数组将用于变量,它应该是tf.feature_column.numeric_column,因为所有输入要素都是mnist中的数字。

就数据输入而言,张量流似乎是专门实现输入功能,这使得数据的输入和格式化变得更加容易。

我发现文档有时很有帮助。看起来你已经在代码中组装了各种部分,而且并非所有部分都是兼容的。例如,固定估计器不接受学习速率输入。太多,隐藏单位的数量远远超过了必要的数量,除非你喜欢看着你的处理器将它的粉丝推到极限。

开发人员博客以清晰,彻底的方式撰写。我建议从博客开始学习如何使用'罐头'估算器。它非常简单易用,可以使机器学习适应您自己的数据集。

看到这个; https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html

答案 1 :(得分:1)

我试图使用SQLFlow的示例代码来解释什么是“ feature_columns”。

显示数据集:

SELECT * from iris.train limit 2;

-----------------------------
+--------------+-------------+--------------+-------------+-------+
| SEPAL LENGTH | SEPAL WIDTH | PETAL LENGTH | PETAL WIDTH | CLASS |
+--------------+-------------+--------------+-------------+-------+
|          6.4 |         2.8 |          5.6 |         2.2 |     2 |
|            5 |         2.3 |          3.3 |           1 |     1 |
+--------------+-------------+--------------+-------------+-------+

培训脚本:

SELECT *
FROM iris.train
TRAIN DNNClassifier
WITH n_classes = 3, hidden_units = [10, 20]
COLUMN sepal_length, sepal_width, petal_length, petal_width  /* This is "feature_columns" */
LABEL class
INTO sqlflow_models.my_dnn_model;

仅供参考:https://github.com/sql-machine-learning/sqlflow/blob/develop/doc/demo.md#training-a-dnnclassifier-and-run-prediction