如何使用 Dataset API 读取包含可变长度列表的 TFRecords 文件?

时间:2017-12-22 10:03:45

标签: python tensorflow tfrecord

我想使用 TensorFlow 的 Dataset API 来读取包含可变长度列表的 TFRecords 文件。这是我的代码。

def _int64_feature(value):
    """Build an int64 ``tf.train.Feature`` from a sequence of integers.

    ``value`` is handed unchanged to ``tf.train.Int64List``; the callers
    in this snippet pass a one-element NumPy array or one row of the
    ragged array.
    """
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)
def main1():
    """Write five variable-length int lists to a TFRecord file and read them back.

    The function (1) serializes each row of ``a`` together with its index
    into the file ``'file'``, (2) prints every record decoded with the
    protobuf API, and (3) tries to read the first record through the
    Dataset API.
    """
    # Write an array to TFRecord.
    # a is an array which contains lists of variant length (rows of 5 or 4 values).
    a = np.array([[0, 54, 91, 153, 177],
                 [0, 50, 89, 147, 196],
                 [0, 38, 79, 157],
                 [0, 49, 89, 147, 177],
                 [0, 32, 73, 145]])

    writer = tf.python_io.TFRecordWriter('file')

    for i in range(a.shape[0]): # i = 0 ~ 4
        x_train = a[i]
        # One Example per row: 'i' is the row index, 'data' the row's values.
        feature = {'i': _int64_feature(np.array([i])), 'data': _int64_feature(x_train)}

        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize to string and write on the file
        writer.write(example.SerializeToString())

    writer.close()

    # Check TFRecord file by decoding each serialized record directly.
    record_iterator = tf.python_io.tf_record_iterator(path='file')
    for string_record in record_iterator:
        example = tf.train.Example()
        example.ParseFromString(string_record)

        i = (example.features.feature['i'].int64_list.value)
        data = (example.features.feature['data'].int64_list.value)
        #data = np.fromstring(data_string, dtype=np.int64)
        print(i, data)

    # Use Dataset API to read the TFRecord file.
    def _parse_function(example_proto):
        """Parse one serialized Example into the tensors ``(i, data)``."""
        # NOTE(review): 'data' was written with 4 or 5 values per record,
        # but FixedLenFeature([]) declares a scalar. This looks like the
        # cause of the reported "Values size: 5 but output shape: []" error.
        keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                            'data':tf.FixedLenFeature([], tf.int64)}
        parsed_features = tf.parse_single_example(example_proto, keys_to_features)
        return parsed_features['i'], parsed_features['data']

    ds = tf.data.TFRecordDataset('file')
    iterator = ds.map(_parse_function).make_one_shot_iterator()
    i, data = iterator.get_next()
    with tf.Session() as sess:
        # NOTE(review): i and data are evaluated in two separate calls;
        # presumably each call advances the iterator, so the two printed
        # values may come from different records - confirm.
        print(i.eval())
        print(data.eval())

检查TFRecord文件

[0] [0, 54, 91, 153, 177]
[1] [0, 50, 89, 147, 196]
[2] [0, 38, 79, 157]
[3] [0, 49, 89, 147, 177]
[4] [0, 32, 73, 145]

但是当我尝试使用Dataset API读取TFRecord文件时,它显示以下错误。

  

tensorflow.python.framework.errors_impl.InvalidArgumentError: Name: <unknown>, Key: data, Index: 0.  Number of int64 values != expected.  Values size: 5 but output shape: []

谢谢。
更新: 我又尝试用下面两段代码通过 Dataset API 读取 TFRecord 文件,但都失败了。

def _parse_function(example_proto):
    """Parse one serialized Example: 'i' as a scalar int64, 'data' as a VarLenFeature."""
    keys_to_features = {'i'   :tf.FixedLenFeature([], tf.int64),
                        'data':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    # NOTE(review): a VarLenFeature is parsed into a SparseTensor, and 'data'
    # is returned here without being densified. That is presumably what the
    # reported "Failed to convert ... SparseTensor ... to Tensor" error is about.
    return parsed_features['i'], parsed_features['data']

# Read the file 'file' through the Dataset API and fetch i and data in one run call.
ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))

def _parse_function(example_proto):
    """Parse one serialized Example with both 'i' and 'data' as VarLenFeature."""
    keys_to_features = {'i'   :tf.VarLenFeature(tf.int64),
                        'data':tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    # NOTE(review): both values are SparseTensors here and are returned
    # without being densified. The reported traceback names the sparse
    # tensor of feature 'i' ("Slice_Indices_i"), which fits this variant.
    return parsed_features['i'], parsed_features['data']

# Read the file 'file' through the Dataset API and fetch i and data in one run call.
ds = tf.data.TFRecordDataset('file')
iterator = ds.map(_parse_function).make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    print(sess.run([i, data]))

错误:

  

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in make_tensor_proto
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 468, in <listcomp>
    str_values = [compat.as_bytes(x) for x in proto_values]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/compat.py", line 65, in as_bytes
    (bytes_or_text,))
TypeError: Expected binary or unicode string, got <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x...>

     

在处理上述异常期间,发生了另一个异常:

     

Traceback (most recent call last):
  File "2tfrecord.py", line 126, in <module>
    main1()
  File "2tfrecord.py", line 72, in main1
    iterator = ds.map(_parse_function).make_one_shot_iterator()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 712, in map
    return MapDataset(self, map_func)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1385, in __init__
    self._map_func.add_to_graph(ops.get_default_graph())
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 486, in add_to_graph
    self._create_definition_if_needed()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
    self._create_definition_if_needed_impl()
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
    outputs = self._func(*inputs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in tf_map_func
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1376, in <listcomp>
    flattened_ret = [ops.convert_to_tensor(t) for t in nest.flatten(ret)]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 836, in convert_to_tensor
    as_ref=False)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 926, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 229, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/constant_op.py", line 208, in constant
    value, dtype=dtype, shape=shape, verify_shape=verify_shape))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_util.py", line 472, in make_tensor_proto
    "supported type." % (type(values), values))
TypeError: Failed to convert object of type <class 'tensorflow.python.framework.sparse_tensor.SparseTensor'> to Tensor. Contents: SparseTensor(indices=Tensor("ParseSingleExample/Slice_Indices_i:0", shape=(?, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExample:3", shape=(?,), dtype=int64), dense_shape=Tensor("ParseSingleExample/Squeeze_Shape_i:0", shape=(1,), dtype=int64)). Consider casting elements to a supported type.

Python版本:3.5.2
Tensorflow版本:1.4.1

2 个答案:

答案 0 :(得分:11)

经过几个小时的搜索和尝试,我想我找到了答案。以下是我的代码。

def _int64_feature(value):
    """Wrap the elements of a NumPy array in an int64 ``tf.train.Feature``.

    ``value`` must provide ``.flatten()`` (i.e. be a NumPy array); the
    flattened elements are stored in the feature's ``Int64List``.
    """
    flat_values = value.flatten()
    int64_list = tf.train.Int64List(value=flat_values)
    return tf.train.Feature(int64_list=int64_list)

# Write an array to a TFRecord file.
# a holds lists of different lengths, so it is built as an explicit object
# array: without dtype=object NumPy >= 1.24 raises ValueError on ragged
# nested sequences (1.19-1.23 emit a deprecation warning).
a = np.array([[0, 54, 91, 153, 177],
              [0, 50, 89, 147, 196],
              [0, 38, 79, 157],
              [0, 49, 89, 147, 177],
              [0, 32, 73, 145]], dtype=object)

# The with-block closes the writer even if building or writing an Example
# raises part-way through.
with tf.python_io.TFRecordWriter('file') as writer:
    for i, row in enumerate(a):  # i = 0 ~ 4
        x_train = np.array(row)
        feature = {'i'   : _int64_feature(np.array([i])),
                   'data': _int64_feature(x_train)}

        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize to string and write on the file
        writer.write(example.SerializeToString())

# Check the TFRecord file by decoding every record with the protobuf API.
record_iterator = tf.python_io.tf_record_iterator(path='file')
for string_record in record_iterator:
    example = tf.train.Example()
    example.ParseFromString(string_record)

    i = (example.features.feature['i'].int64_list.value)
    data = (example.features.feature['data'].int64_list.value)
    print(i, data)

# Use Dataset API to read the TFRecord file.
filenames = ["file"]
dataset = tf.data.TFRecordDataset(filenames)


def _parse_function(example_proto):
    """Parse one serialized Example into dense ``(i, data)`` tensors.

    Both features are declared as VarLenFeature, which parses to a
    SparseTensor; each one is converted to a dense tensor before it is
    returned from the map function.
    """
    keys_to_features = {'i': tf.VarLenFeature(tf.int64),
                        'data': tf.VarLenFeature(tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, keys_to_features)
    return (tf.sparse_tensor_to_dense(parsed_features['i']),
            tf.sparse_tensor_to_dense(parsed_features['data']))


# Parse the record into tensors.
dataset = dataset.map(_parse_function)
# Shuffle the dataset.
# NOTE: buffer_size=1 keeps the records in file order; a real shuffle needs a
# larger buffer (e.g. the number of records).
dataset = dataset.shuffle(buffer_size=1)
# Repeat the input indefinitely
dataset = dataset.repeat()
# Generate batches.
# NOTE: the dense 'data' tensors have different lengths, so batches larger
# than 1 would presumably need padded_batch() instead of batch().
dataset = dataset.batch(1)
# Create a one-shot iterator
iterator = dataset.make_one_shot_iterator()
i, data = iterator.get_next()
with tf.Session() as sess:
    # Fetch i and data in the same run call so both come from one record.
    for _ in range(3):
        print(sess.run([i, data]))

有几点需要注意:
1. 这个SO问题对我帮助很大。
2. tf.VarLenFeature 会返回 SparseTensor,因此必须使用 tf.sparse_tensor_to_dense 将其转换为密集张量。
3. 在我的代码中,parse_single_example() 不能替换为 parse_example(),这个问题困扰了我一整天。我不知道为什么 parse_example() 不行。如果有人知道原因,请赐教。

答案 1 :(得分:2)

错误很简单。您的 data 不是 FixedLenFeature,而是 VarLenFeature。请将这一行:

 'data':tf.FixedLenFeature([], tf.int64)}

替换为:

 'data':tf.VarLenFeature(tf.int64)}

此外,当您分别调用 print(i.eval()) 和 print(data.eval()) 时,实际上让迭代器前进了两次。第一个 print 会打印 0,但第二个会打印第二行的值 [ 0, 50, 89, 147, 196]。您可以用 print(sess.run([i, data])) 从同一条记录中同时获取 i 和 data。