Using tf.feature_column.make_parse_example_spec with tf.contrib.data.Dataset

Posted: 2017-10-16 10:54:05

Tags: python tensorflow tfrecord

I created a simple tfrecords file with the following code:

import tensorflow as tf

data_rows = [
    ("Male", "White", 12332.0),
    ("Female", "Black", 232324.0),
    ("Female", "Other", 12313.0)
]


def float_feature(x):
    # Wrap a single number as a FloatList feature.
    return tf.train.Feature(float_list=tf.train.FloatList(value=[float(x)]))


def string_feature(x):
    # Wrap a single string as an ASCII-encoded BytesList feature.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[x.strip().encode("ascii")]))

tfrecords_path = "simple.tfrecords"

with tf.python_io.TFRecordWriter(tfrecords_path) as writer:
    for gender, race, earnings in data_rows:
        example = tf.train.Example(features=tf.train.Features(feature={
            "gender": string_feature(gender),
            "race": string_feature(race),
            "earnings": float_feature(earnings)
        }))
        # Write inside the loop so that every row becomes one record.
        writer.write(example.SerializeToString())
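The indentation of `writer.write` matters: if it sits after the `for` loop instead of inside it, only the last `example` is written to the file. A minimal plain-Python sketch of the difference, using a list-backed `ListWriter` as a hypothetical stand-in for `TFRecordWriter` (no TensorFlow involved):

```python
# Hypothetical stand-in for tf.python_io.TFRecordWriter: collects records in a list.
class ListWriter:
    def __init__(self):
        self.records = []

    def write(self, record):
        self.records.append(record)


data_rows = [
    ("Male", "White", 12332.0),
    ("Female", "Black", 232324.0),
    ("Female", "Other", 12313.0),
]


def write_after_loop(rows):
    writer = ListWriter()
    for gender, race, earnings in rows:
        example = (gender, race, earnings)
    writer.write(example)  # runs once, after the loop: only the last row survives
    return writer.records


def write_inside_loop(rows):
    writer = ListWriter()
    for gender, race, earnings in rows:
        example = (gender, race, earnings)
        writer.write(example)  # runs once per row
    return writer.records


print(len(write_after_loop(data_rows)))   # prints 1
print(len(write_inside_loop(data_rows)))  # prints 3
```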

Now I want to create a tf.contrib.data.Dataset from it using tf.feature_column feature columns. I have defined the feature columns like this:

gender = tf.feature_column.categorical_column_with_vocabulary_list("gender", ["Female", "Male"])
race = tf.feature_column.categorical_column_with_vocabulary_list("race", ["White", "Black", "Other"])
earnings = tf.feature_column.numeric_column("earnings")
parse_spec = tf.feature_column.make_parse_example_spec([gender, race, earnings])
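As far as I understand it, `make_parse_example_spec` maps each `categorical_column_with_vocabulary_list` to a `tf.VarLenFeature` (which `parse_single_example` returns as a `SparseTensor`) and each `numeric_column` to a `tf.FixedLenFeature` with shape `(1,)` (returned as a dense tensor). That would explain why `gender` and `race` come back sparse while `earnings` is dense, which is also what the queue-based output further down shows. The sketch below mirrors that mapping in plain Python; `VarLen`, `FixedLen`, `CategoricalVocab`, `Numeric`, and `parse_spec_for` are hypothetical names, not TensorFlow APIs.

```python
from collections import namedtuple

# Hypothetical mirrors of tf.VarLenFeature / tf.FixedLenFeature (not the real classes).
VarLen = namedtuple("VarLen", ["dtype"])
FixedLen = namedtuple("FixedLen", ["shape", "dtype"])

# Hypothetical mirrors of the two kinds of feature columns used in the question.
CategoricalVocab = namedtuple("CategoricalVocab", ["key", "vocabulary"])
Numeric = namedtuple("Numeric", ["key", "shape"])


def parse_spec_for(columns):
    """Build a dict of feature key -> parsing config, like make_parse_example_spec."""
    spec = {}
    for col in columns:
        if isinstance(col, CategoricalVocab):
            # Variable-length string feature -> parsed as a sparse tensor.
            spec[col.key] = VarLen(dtype="string")
        elif isinstance(col, Numeric):
            # Fixed-length float feature -> parsed as a dense tensor.
            spec[col.key] = FixedLen(shape=col.shape, dtype="float32")
        else:
            raise TypeError("unsupported column: %r" % (col,))
    return spec


columns = [
    CategoricalVocab("gender", ["Female", "Male"]),
    CategoricalVocab("race", ["White", "Black", "Other"]),
    Numeric("earnings", (1,)),
]
for key, cfg in parse_spec_for(columns).items():
    print(key, cfg)
```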

Then I tried to apply a parse function to each row of the raw tfrecords dataset, using the feature columns defined above, with the following code:

raw_dataset = tf.contrib.data.TFRecordDataset(tfrecords_path)

def parse_example(example):
    parsed_features = tf.parse_single_example(example, parse_spec)
    return parsed_features

dataset = raw_dataset.map(parse_example)

This breaks with the following stack trace:

<TFRecordDataset shapes: (), types: tf.string>
Traceback (most recent call last):
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 460, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 460, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f49714edf98>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "simpletfrecords.py", line 87, in <module>
    parse_with_dataset()
  File "simpletfrecords.py", line 53, in parse_with_dataset
    dataset = raw_dataset.map(parse_example)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 964, in map
    return MapDataset(self, map_func, num_threads, output_buffer_size)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1735, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/function.py", line 449, in add_to_graph
    self._create_definition_if_needed()
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/framework/function.py", line 168, in _create_definition_if_needed
    outputs = self._func(*inputs)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1726, in tf_map_func
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/data/python/ops/dataset_ops.py", line 1726, in <listcomp>
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 611, in convert_to_tensor
    as_ref=False)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 676, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 121, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 102, in constant
    tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/home/john/opt/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/tensor_util.py", line 464, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_gender:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:2", shape=(?,), dtype=string), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_gender:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.

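So the parse function seems to expect a byte string to read the features from, but for some reason I don't understand, it already appears to be a SparseTensor.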
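Looking at the traceback, the error is raised inside `tf_map_func`, at `ops.convert_to_tensor(t) for t in nest.flatten(ret)`, i.e. while converting the values that `parse_example` *returns*, not while reading its input (the input is the plain `tf.string` record shown by `<TFRecordDataset shapes: (), types: tf.string>`). So the `SparseTensor` is the output of `parse_single_example` for the two `VarLenFeature` columns, and this version of `tf.contrib.data` apparently cannot carry sparse values through `map` (my understanding is that later `tf.data` releases added `SparseTensor` support, but check the release notes for your version). One possible workaround, not verified here against TF 1.3, is to densify the sparse outputs inside the map function, e.g. with `tf.sparse_tensor_to_dense(sp, default_value="")` for the string columns. The plain-Python sketch below only illustrates that densify step on a hypothetical `Sparse` record with the same `indices` / `values` / `dense_shape` layout as the `SparseTensorValue` output shown further down; it is not TensorFlow code.

```python
from collections import namedtuple

# Hypothetical mirror of SparseTensorValue: indices, values, dense_shape.
Sparse = namedtuple("Sparse", ["indices", "values", "dense_shape"])


def sparse_to_dense_1d(sp, default_value):
    """Densify a rank-1 sparse value, filling missing positions with default_value."""
    (length,) = sp.dense_shape
    dense = [default_value] * length
    for (i,), v in zip(sp.indices, sp.values):
        dense[i] = v
    return dense


def densify_outputs(parsed, default_value=b""):
    """Replace every sparse entry of a parsed-features dict with a dense list."""
    return {
        key: sparse_to_dense_1d(val, default_value) if isinstance(val, Sparse) else val
        for key, val in parsed.items()
    }


# Same layout as the queue-based output shown in the question.
parsed = {
    "gender": Sparse(indices=[[0]], values=[b"Male"], dense_shape=[1]),
    "race": Sparse(indices=[[0]], values=[b"White"], dense_shape=[1]),
    "earnings": [12332.0],
}
print(densify_outputs(parsed))
```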

I then tried parsing the same tfrecords file the "old" way, with queues:

filename_queue = tf.train.string_input_producer([tfrecords_path], num_epochs=1)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features_op = tf.parse_single_example(serialized_example, parse_spec)
...
sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

features = sess.run(features_op)

This gives me the following output, so it seems to work, I guess:

{'gender': SparseTensorValue(indices=array([[0]]), values=array([b'Male'], dtype=object), dense_shape=array([1])), 'race': SparseTensorValue(indices=array([[0]]), values=array([b'White'], dtype=object), dense_shape=array([1])), 'earnings': array([ 12332.], dtype=float32)}

But my question is: how do I make this work using tf.contrib.data.Dataset?

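TL;DR: I want to use tf.feature_column to define my features, and then create input functions that read from tfrecord files and are compatible with the tf.estimator classes. Ideally I would like to use the tf.contrib.data.Dataset API for this.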
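For the TL;DR goal: an `Estimator` `input_fn` is expected to return a `(features, labels)` pair, where `features` is a dict keyed by the same names the feature columns use (`"gender"`, `"race"`, `"earnings"`). With the Dataset API that is typically done by batching the parsed dataset and taking an iterator's `get_next()`. The plain-Python sketch below only illustrates the shape of that contract by grouping parsed dicts into dict-of-lists batches. `batched_input`, `_collate`, and the `label_key` parameter are hypothetical, and using `earnings` as the label in the usage example is purely for illustration, since the question does not name a label.

```python
def _collate(batch, label_key):
    # Turn a list of per-example dicts into one dict of per-feature lists.
    features = {key: [ex[key] for ex in batch] for key in batch[0]}
    # Optionally split one column off as the label.
    labels = features.pop(label_key) if label_key is not None else None
    return features, labels


def batched_input(examples, batch_size, label_key=None):
    """Yield (features, labels) batches from an iterable of parsed dicts."""
    batch = []
    for ex in examples:
        batch.append(ex)
        if len(batch) == batch_size:
            yield _collate(batch, label_key)
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield _collate(batch, label_key)


rows = [
    {"gender": "Male", "race": "White", "earnings": 12332.0},
    {"gender": "Female", "race": "Black", "earnings": 232324.0},
    {"gender": "Female", "race": "Other", "earnings": 12313.0},
]
for features, labels in batched_input(rows, batch_size=2, label_key="earnings"):
    print(features, labels)
```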

0 answers