每个示例中具有可变长度数据的批处理的tensorflow tfrecords?

时间:2018-03-15 15:16:09

标签: python tensorflow

我正在尝试使用tfrecords来读取每个示例中具有可变长度列表的字段的批处理。数据可能是

example 1: [1,    2,  3] 
example 2: [10,  11]
example 3: [100,200,300,400]

我一直在使用

tf.train.Feature(int64_list=tf.Int64List(value=x))

存储以上各项,即我将制作3个不同的tfrecords,其中x为[1,2,3],然后是[10,11],然后是[100,200,300,400]

这三个记录调用了SerializeToString()方法,并通过TFRecordWriter

附加到文件中

回读很棘手,我无法使用tf.FixedLenFeature,所以我找到了tf.VarLenFeature。这看起来非常好,当我以3的批量大小读取数据批次时,看起来我得到的是tf.SparseVectorValue,其中索引的第0列是批处理中的示例编号,第1列是列表中的值,也就是说,它看起来像是我得到的(假设批量大小为3):

indices=[[0,0],
         [0,1],
         [0,2],
         [1,0],
         [1,1],
         [2,0],
         [2,1],
         [2,2],
         [2,3]]
 values = [1,2,3,10,11,100,200,300,400]

但是现在我正在处理更多数据,我认为这不是我得到的。

我的问题是,批量填充像这样的可变长度列表时VarLenFeature会返回什么?它应该做我解释的吗?那么也许我有一个bug要找。

但如果它做了不同的事情,那么我该怎么做才能读回一批带有可变长度列表的数据呢?我需要知道每个列表的批处理中的示例编号,我可以使用每个列表的长度向tfrecord添加另一个字段。

- 编辑 -

我做了更多测试,我认为它的工作方式与我的想法相同。我必须在我更大的计划中遇到问题。如果有文档准确说明批量数据集应该返回tf.VarLenFeature,那将是很好的,所以我可以确定我的上述解释是正确的。

下面是我正在尝试的一些测试代码:

import numpy as np
import tensorflow as tf
import random


def make_original_data(N):
    data = []
    for uid in range(N):
        varlen = random.randint(2, 20)
        varx = [random.randint(0,100) for kk in range(varlen)] 
        rec = {'uid': uid,
               'A':random.randint(0,100),
               'B':random.randint(0,100),
               'V':varx,
               'nV':varlen
            }
        data.append(rec)
    return data


def rec2tfrec_example(rec):
    def _int64_feat(value):
        arr_value = np.empty([1], dtype=np.int64)
        arr_value[0] = value
        return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_value))

    def _int64list_feat(values):
        arr_values = np.empty([len(values)], dtype=np.int64)
        arr_values[:] = values[:]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_values))

    feat = {
        'uid': _int64_feat(rec['uid']),
        'A':   _int64_feat(rec['A']),
        'B':   _int64_feat(rec['B']),
        'nV':  _int64_feat(rec['nV']),
        'V':   _int64list_feat(rec['V'])
        }

    return tf.train.Example(features=tf.train.Features(feature=feat))


def parse_example(tfrec_serialized_string):
    feat = {
        'uid': tf.FixedLenFeature([], tf.int64),
        'A': tf.FixedLenFeature([], tf.int64),
        'B': tf.FixedLenFeature([], tf.int64),
        'nV': tf.FixedLenFeature([], tf.int64),
        'V': tf.VarLenFeature(tf.int64)
    }
    return tf.parse_example(tfrec_serialized_string, feat)


def to_serialized_tfrecs(data):
    serialized = []
    for rec in data:
        example = rec2tfrec_example(rec)
        serialized.append(example.SerializeToString())
    return serialized


def write_tfrecs_to_file(fname, recs):
        recwriter = tf.python_io.TFRecordWriter(fname)
        for rec in recs:
            recwriter.write(bytes(rec))
        recwriter.close()


def check_batch(data_batch, tfres):
    for ky in ['A', 'uid', 'B', 'nV']:
        orig_data = np.array([rec[ky] for rec in data_batch], dtype=np.int64)
        assert np.all(orig_data == tfres[ky]), "batch_idx=%d ky=%s orig=%s tf=%s" % \
            (batch_idx, ky, orig_data, tfres[ky])
    spTensorValue = tfres['V']
    tf_example_in_batch = spTensorValue.indices[:,0]
    tf_V = spTensorValue.values
    tf_sum_nV = np.sum(tfres['nV'])
    assert tf_sum_nV == len(tf_V), "tf_sum_nV=%d != len(tf_V)=%d" % (tf_sum_nV, len(tf_V))
    ex_example_in_batch = np.empty((tf_sum_nV,), np.int64)
    ex_V = np.empty((tf_sum_nV,), np.int64)
    idx = 0
    for example_in_batch, tf_num_this_example in enumerate(tfres['nV']):
        num_this_example = data_batch[example_in_batch]['nV']
        assert num_this_example == tf_num_this_example
        ex_example_in_batch[idx:idx+num_this_example] = example_in_batch
        ex_V[idx:idx+num_this_example] = np.array(data_batch[example_in_batch]['V'])
        idx += num_this_example
    assert np.all(ex_example_in_batch == tf_example_in_batch), "example in batch wrong, expected=%s != tf=%s" % (ex_example_in_batch, tf_example_in_batch)
    assert np.all(ex_V == tf_V), "example in batch wrong, expected=%s != tf=%s" % (ex_V, tf_V)


def check_tfrecs(sess, tfrec_output_filename, data, N, batch_size):
    dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
                     .batch(batch_size) \
                     .map(parse_example, num_parallel_calls=2)
    tf_iter = dataset.make_initializable_iterator()
    get_next = tf_iter.get_next()

    sess.run(tf_iter.initializer)
    num_batches = N//batch_size
    nextIdx = 0
    for batch_idx in range(num_batches):
        data_batch = [data[idx] for idx in range(nextIdx, nextIdx + batch_size)]
        nextIdx += batch_size
        tfres = sess.run(get_next)
        check_batch(data_batch, tfres)


def main(N=1000, batch_size=5, tfrec_output_filename='tfrec_testing.tfrecords'):
    tf.reset_default_graph()
    data = make_original_data(N)
    tfrec_strings = to_serialized_tfrecs(data)
    write_tfrecs_to_file(tfrec_output_filename, tfrec_strings)
    with tf.Session() as sess:
        check_tfrecs(sess, tfrec_output_filename, data, N, batch_size)

if __name__ == '__main__':
    main()

0 个答案:

没有答案