I created a simple tfrecords file with the following code:
import tensorflow as tf
data_rows = [
    ("Male", "White", 12332.0),
    ("Female", "Black", 232324.0),
    ("Female", "Other", 12313.0)
]
def float_feature(x):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[float(x)]))
def string_feature(x):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x.strip().encode("ascii")]))
tfrecords_path = "simple.tfrecords"
with tf.python_io.TFRecordWriter(tfrecords_path) as writer:
    for gender, race, earnings in data_rows:
        # One tf.train.Example per row: two string features and one float feature.
        example = tf.train.Example(features=tf.train.Features(feature={
            "gender": string_feature(gender),
            "race": string_feature(race),
            "earnings": float_feature(earnings)
        }))
        writer.write(example.SerializeToString())
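Each call to writer.write appends one framed record to the file. The stdlib-only sketch below assumes the TFRecord layout described in TensorFlow's TFRecord documentation (a little-endian uint64 length, a masked CRC-32C of those length bytes, the payload, then a masked CRC-32C of the payload). The helper names crc32c, masked_crc, write_records and read_records are hypothetical and are not TensorFlow functions. Each payload read back is one serialized tf.train.Example, i.e. the kind of scalar tf.string element that TFRecordDataset yields.

```python
import struct

# CRC-32C (Castagnoli) lookup table, reflected polynomial 0x82F63B78.
_CRC32C_TABLE = []
for _i in range(256):
    _c = _i
    for _ in range(8):
        _c = (_c >> 1) ^ 0x82F63B78 if _c & 1 else _c >> 1
    _CRC32C_TABLE.append(_c)


def crc32c(data):
    """Plain CRC-32C of a bytes object (init and final XOR 0xFFFFFFFF)."""
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Assumed TFRecord masking: rotate the CRC right by 15 bits, add a constant."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_records(payloads):
    """Hypothetical helper: frame each payload as a TFRecord-style record."""
    out = bytearray()
    for data in payloads:
        length = struct.pack("<Q", len(data))
        out += length
        out += struct.pack("<I", masked_crc(length))
        out += data
        out += struct.pack("<I", masked_crc(data))
    return bytes(out)


def read_records(buf):
    """Hypothetical helper: yield payloads, verifying both checksums."""
    pos = 0
    while pos < len(buf):
        length_bytes = buf[pos:pos + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", buf[pos + 8:pos + 12])
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupt length at offset %d" % pos)
        start = pos + 12
        data = buf[start:start + length]
        (data_crc,) = struct.unpack("<I", buf[start + length:start + length + 4])
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record at offset %d" % pos)
        yield data
        pos = start + length + 4
```

To inspect the file written above, one could pass open(tfrecords_path, "rb").read() to read_records and decode each payload with tf.train.Example.FromString.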
Now I want to create a tf.contrib.data.Dataset from it using tf.feature_column feature columns. I have defined the features like this:
gender = tf.feature_column.categorical_column_with_vocabulary_list("gender", ["Female", "Male"])
race = tf.feature_column.categorical_column_with_vocabulary_list("race", ["White", "Black", "Other"])
earnings = tf.feature_column.numeric_column("earnings")
parse_spec = tf.feature_column.make_parse_example_spec([gender, race, earnings])
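The queue-based output near the end of this question shows gender and race coming back as sparse values while earnings is a dense length-1 array. That is consistent with make_parse_example_spec describing the vocabulary-list columns as variable-length features (tf.VarLenFeature) and the numeric column as a fixed-length feature of shape (1,) (tf.FixedLenFeature). The pure-Python model below mimics that behaviour for a single example; SparseValue, PARSE_SPEC and parse_single are hypothetical stand-ins, not TensorFlow APIs.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.SparseTensorValue.
SparseValue = namedtuple("SparseValue", ["indices", "values", "dense_shape"])

# Hypothetical model of parse_spec: "varlen" ~ tf.VarLenFeature,
# ("fixed", n) ~ tf.FixedLenFeature with shape (n,).
PARSE_SPEC = {"gender": "varlen", "race": "varlen", "earnings": ("fixed", 1)}


def parse_single(raw_features, spec):
    """raw_features maps each feature name to the list of values stored in one Example."""
    parsed = {}
    for name, kind in spec.items():
        values = raw_features.get(name, [])
        if kind == "varlen":
            # Any number of values (including zero) -> sparse representation.
            parsed[name] = SparseValue(
                indices=[[i] for i in range(len(values))],
                values=list(values),
                dense_shape=[len(values)],
            )
        else:
            # Fixed length: the value count must match the declared shape.
            _, size = kind
            if len(values) != size:
                raise ValueError("%s: expected %d values, got %d" % (name, size, len(values)))
            parsed[name] = list(values)
    return parsed
```

Applied to the first row written above, parse_single({"gender": [b"Male"], "race": [b"White"], "earnings": [12332.0]}, PARSE_SPEC) gives sparse values for gender and race and a plain list for earnings, mirroring the structure of the queue-based output.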
Then I tried to apply a parsing function, built from the features defined above, to each record of the raw tfrecords dataset:
raw_dataset = tf.contrib.data.TFRecordDataset(tfrecords_path)
def parse_example(example):
    parsed_features = tf.parse_single_example(example, parse_spec)
    return parsed_features
dataset = raw_dataset.map(parse_example)
This fails with the following stack trace:
<TFRecordDataset shapes: (), types: tf.string>
Traceback (most recent call last):
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 460, in make_tensor_proto
str_values = [compat.as_bytes(x) for x in proto_values]
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 460, in <listcomp>
str_values = [compat.as_bytes(x) for x in proto_values]
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
(bytes_or_text,))
TypeError: Expected binary or unicode string, got <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f49714edf98>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "simpletfrecords.py", line 87, in <module>
parse_with_dataset()
File "simpletfrecords.py", line 53, in parse_with_dataset
dataset = raw_dataset.map(parse_example)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 964, in map
return MapDataset(self, map_func, num_threads, output_buffer_size)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1735, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 449, in add_to_graph
self._create_definition_if_needed()
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/framework/function.py", line 168, in _create_definition_if_needed
outputs = self._func(*inputs)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1726, in tf_map_func
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1726, in <listcomp>
flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 611, in convert_to_tensor
as_ref=False)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 676, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 121, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 102, in constant
tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 464, in make_tensor_proto
"supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_gender:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:2", shape=(?,), dtype=string), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_gender:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.
So the parse function expects a byte string to read these features from, but for some reason I don't understand, it seems to be a SparseTensor already.
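The second traceback points at the line in this version of dataset_ops.py that calls ops.convert_to_tensor on every value returned by the map function, and that conversion fails for the SparseTensor objects that parse_single_example produces for the variable-length gender and race features. One possible workaround, assuming only dense outputs are accepted here, is to densify those values inside parse_example, for example with tf.sparse_tensor_to_dense and a string default_value such as "". The stdlib sketch below models only that densification step on plain Python data; sparse_to_dense and densify_features are hypothetical helpers, not TensorFlow functions.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.SparseTensorValue.
SparseValue = namedtuple("SparseValue", ["indices", "values", "dense_shape"])


def sparse_to_dense(sp, default_value):
    """Scatter sp.values into a nested list of shape sp.dense_shape."""
    def build(shape):
        if len(shape) == 1:
            return [default_value] * shape[0]
        return [build(shape[1:]) for _ in range(shape[0])]

    dense = build(list(sp.dense_shape))
    for index, value in zip(sp.indices, sp.values):
        target = dense
        for i in index[:-1]:
            target = target[i]
        target[index[-1]] = value
    return dense


def densify_features(features, default_value=b""):
    """Replace every sparse value in a parsed-feature dict with its dense form."""
    return {
        name: sparse_to_dense(v, default_value) if isinstance(v, SparseValue) else v
        for name, v in features.items()
    }
```

Densifying trades the compact sparse layout for a rectangular one padded with default_value, which is harmless here because every example stores exactly one gender and one race.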
Then I tried parsing the same tfrecords file the "old" way, with queues:
filename_queue = tf.train.string_input_producer([tfrecords_path], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features_op = tf.parse_single_example(serialized_example, parse_spec)
...
sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
features = sess.run(features_op)
This gives me the following output, so it seems to work, I guess:
{'gender': SparseTensorValue(indices=array([[0]]), values=array([b'Male'], dtype=object), dense_shape=array([1])), 'race': SparseTensorValue(indices=array([[0]]), values=array([b'White'], dtype=object), dense_shape=array([1])), 'earnings': array([ 12332.], dtype=float32)}
But my question is: how do I make this work with tf.contrib.data.Dataset?
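TL;DR: I want to use tf.feature_column to define my features, and then create input functions that read from tfrecords files and are compatible with the tf.estimator classes. Preferably, I'd like to use the tf.contrib.data.Dataset API for this.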
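An estimator input_fn usually returns batched features. For a variable-length feature, a batch of N examples is represented as a single sparse value whose indices are [example, position] pairs and whose dense_shape is [N, maximum number of values in any example]; this is the layout documented for tf.parse_example. The pure-Python sketch below models how per-example sparse values would combine into that layout; SparseValue and batch_sparse are hypothetical names, not TensorFlow APIs.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.SparseTensorValue.
SparseValue = namedtuple("SparseValue", ["indices", "values", "dense_shape"])


def batch_sparse(rows):
    """Stack per-example rank-1 sparse values into one rank-2 sparse value."""
    indices, values = [], []
    max_len = 0
    for row_id, sp in enumerate(rows):
        for (pos,), value in zip(sp.indices, sp.values):
            # Prefix each position with the example's row number in the batch.
            indices.append([row_id, pos])
            values.append(value)
        max_len = max(max_len, sp.dense_shape[0])
    return SparseValue(indices, values, [len(rows), max_len])
```

For the three rows written at the top of this question, the gender feature would batch to indices [[0, 0], [1, 0], [2, 0]], values [b"Male", b"Female", b"Female"] and dense_shape [3, 1].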