我正在尝试学习如何使用Estimator
API,使用input_fn
向Dataset
生成的输入图层提供feature_column
支持的输入。
我的代码看起来像
import tensorflow as tf import random
tf.logging.set_verbosity(tf.logging.DEBUG)
def input_fn():
def gen():
for i in range(100000):
for j in range(10):
yield {"in": str(j)}, [str(j+1)]
data = tf.data.Dataset.from_generator(gen, ({"in": tf.string}, tf.string))
data = data.batch(10)
iterator = data.make_one_shot_iterator()
return iterator.get_next()
vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_list(
key="in",
vocabulary_list=map(lambda i: str(i), range(11)))
embedding_column = tf.feature_column.embedding_column(
categorical_column=vocabulary_feature_column,
dimension=2)
with tf.Session() as sess:
print(sess.run(input_fn()))
classifier = tf.estimator.DNNClassifier(
feature_columns = [embedding_column],
hidden_units = [5,5],
n_classes = 11,
model_dir = '/tmp/predict/snap')
classifier.train(
input_fn=input_fn)
但运行它我
Traceback (most recent call last): File "predict.py", line 33, in input_fn=input_fn) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 302, in train loss = self._train_model(input_fn, hooks, saving_listeners) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 711, in _train_model features, labels, model_fn_lib.ModeKeys.TRAIN, self.config) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/estimator.py", line 694, in _call_model_fn model_fn_results = self._model_fn(features=features, **kwargs) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 334, in _model_fn config=config) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 190, in _dnn_model_fn logits = logit_fn(features=features, mode=mode) File "/usr/lib/python2.7/site-packages/tensorflow/python/estimator/canned/dnn.py", line 89, in dnn_logit_fn features=features, feature_columns=feature_columns) File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 230, in input_layer trainable=trainable) File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1837, in _get_dense_tensor inputs, weight_collections=weight_collections, trainable=trainable) File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 2123, in _get_sparse_tensors return _CategoricalColumn.IdWeightPair(inputs.get(self), None) File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1533, in get transformed = column._transform_feature(self) # pylint: disable=protected-access File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 2091, in _transform_feature input_tensor = _to_sparse_input(inputs.get(self.key)) File "/usr/lib/python2.7/site-packages/tensorflow/python/feature_column/feature_column.py", line 1631, in _to_sparse_input raise ValueError('Undefined input_tensor shape.') ValueError: Undefined input_tensor shape.
查看源代码我得到的印象是,categorical_column_with_vocabulary_list需要张量作为输出而不是字符串,但我很难理解如何使input_fn以正确的方式提供。
有谁知道我在这里做错了什么?
作为比较,以下代码完全正常:https://pastebin.com/28QUNAjA
修改
我注意到用tf.data.Dataset.from_generator
替换tf.data.Dataset.from_tensor_slices
会使代码运行。
即。以下实际上有效:
import tensorflow as tf
import random
tf.logging.set_verbosity(tf.logging.DEBUG)
def input_fn():
data = tf.data.Dataset.from_tensor_slices(({"in": map(lambda i: str(i), range(10))}, range(1,11)))
data = data.repeat(1000)
data = data.batch(10)
iterator = data.make_one_shot_iterator()
return iterator.get_next()
vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_list(
key="in",
vocabulary_list=map(lambda i: str(i), range(11)))
embedding_column = tf.feature_column.embedding_column(
categorical_column=vocabulary_feature_column,
dimension=2)
with tf.Session() as sess:
print(sess.run(input_fn()))
classifier = tf.estimator.DNNClassifier(
feature_columns = [embedding_column],
hidden_units = [5,5],
n_classes = 11,
model_dir = '/usr/local/google/home/zond/tmp/predict/snap')
classifier.train(
input_fn=input_fn)
这应该是一个错误,所以我创建了https://github.com/tensorflow/tensorflow/issues/15178。