我已成功使用了来自tensorflow的mnist数据集,这是我从头开始编码的深度神经网络模型。现在我想尝试使用tensorflow中更简单的DNNClassifier函数。我面临两个问题,我似乎无法在互联网上找到解决方案
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
estimator = tf.estimator.DNNClassifier(
feature_columns=[28,9], #WTH is this!?!!
hidden_units=[28, 512, 9],
optimizer=tf.train.GradientDescentOptimizer(
learning_rate=0.1,
))
x,y =mnist.train.next_batch(500)
estimator.train(input_fn=x,y,steps = 100)
有一个参数绑定到DNNClassifier,即feature_columns。它的文档对我来说没有意义。什么是feature_columns?它在深层神经网络中扮演什么角色?应该给出什么类型的变量? (list,tuple,numpy array?)
如何将训练数据输入估算器?我无法解决feature_columns问题,因此我无法让估算工作,这意味着我无法让培训工作。但我相信我目前的编码是错误的做法。
答案 0 :(得分:1)
该变量创建feature_columns,它指定模型的输入。不确定为什么数组将用于变量,它应该是tf.feature_column.numeric_column,因为所有输入要素都是mnist中的数字。
就数据输入而言,张量流似乎是专门实现输入功能,这使得数据的输入和格式化变得更加容易。
我发现文档有时很有帮助。看起来你已经在代码中组装了各种部分,而且并非所有部分都是兼容的。例如,固定估计器不接受学习速率输入。太多,隐藏单位的数量远远超过了必要的数量,除非你喜欢看着你的处理器将它的粉丝推到极限。
开发人员博客以清晰,彻底的方式撰写。我建议从博客开始学习如何使用'罐头'估算器。它非常简单易用,可以使机器学习适应您自己的数据集。
看到这个; https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
答案 1 :(得分:1)
我试图使用SQLFlow的示例代码来解释什么是“ feature_columns”。
显示数据集:
SELECT * from iris.train limit 2;
-----------------------------
+--------------+-------------+--------------+-------------+-------+
| SEPAL LENGTH | SEPAL WIDTH | PETAL LENGTH | PETAL WIDTH | CLASS |
+--------------+-------------+--------------+-------------+-------+
| 6.4 | 2.8 | 5.6 | 2.2 | 2 |
| 5 | 2.3 | 3.3 | 1 | 1 |
+--------------+-------------+--------------+-------------+-------+
培训脚本:
SELECT *
FROM iris.train
TRAIN DNNClassifier
WITH n_classes = 3, hidden_units = [10, 20]
COLUMN sepal_length, sepal_width, petal_length, petal_width /* This is "feature_columns" */
LABEL class
INTO sqlflow_models.my_dnn_model;