我正在尝试使用tfrecords来读取每个示例中具有可变长度列表的字段的批处理。数据可能是
example 1: [1, 2, 3]
example 2: [10, 11]
example 3: [100,200,300,400]
我一直在使用
tf.train.Feature(int64_list=tf.Int64List(value=x))
存储以上各项,即我将制作3个不同的tfrecords,其中x为[1,2,3]
,然后是[10,11]
,然后是[100,200,300,400]
这三个记录调用了SerializeToString()
方法,并通过TFRecordWriter
回读很棘手,我无法使用tf.FixedLenFeature
,所以我找到了tf.VarLenFeature
。这看起来非常好,当我以3的批量大小读取数据批次时,看起来我得到的是tf.SparseVectorValue
,其中索引的第0列是批处理中的示例编号,第1列是列表中的值,也就是说,它看起来像是我得到的(假设批量大小为3):
indices=[[0,0],
[0,1],
[0,2],
[1,0],
[1,1],
[2,0],
[2,1],
[2,2],
[2,3]]
values = [1,2,3,10,11,100,200,300,400]
但是现在我正在处理更多数据,我认为这不是我得到的。
我的问题是,批量填充像这样的可变长度列表时VarLenFeature会返回什么?它应该做我解释的吗?那么也许我有一个bug要找。
但如果它做了不同的事情,那么我该怎么做才能读回一批带有可变长度列表的数据呢?我需要知道每个列表的批处理中的示例编号,我可以使用每个列表的长度向tfrecord添加另一个字段。
- 编辑 -
我做了更多测试,我认为它的工作方式与我的想法相同。我必须在我更大的计划中遇到问题。如果有文档准确说明批量数据集应该返回tf.VarLenFeature
,那将是很好的,所以我可以确定我的上述解释是正确的。
下面是我正在尝试的一些测试代码:
import numpy as np
import tensorflow as tf
import random
def make_original_data(N):
data = []
for uid in range(N):
varlen = random.randint(2, 20)
varx = [random.randint(0,100) for kk in range(varlen)]
rec = {'uid': uid,
'A':random.randint(0,100),
'B':random.randint(0,100),
'V':varx,
'nV':varlen
}
data.append(rec)
return data
def rec2tfrec_example(rec):
def _int64_feat(value):
arr_value = np.empty([1], dtype=np.int64)
arr_value[0] = value
return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_value))
def _int64list_feat(values):
arr_values = np.empty([len(values)], dtype=np.int64)
arr_values[:] = values[:]
return tf.train.Feature(int64_list=tf.train.Int64List(value=arr_values))
feat = {
'uid': _int64_feat(rec['uid']),
'A': _int64_feat(rec['A']),
'B': _int64_feat(rec['B']),
'nV': _int64_feat(rec['nV']),
'V': _int64list_feat(rec['V'])
}
return tf.train.Example(features=tf.train.Features(feature=feat))
def parse_example(tfrec_serialized_string):
feat = {
'uid': tf.FixedLenFeature([], tf.int64),
'A': tf.FixedLenFeature([], tf.int64),
'B': tf.FixedLenFeature([], tf.int64),
'nV': tf.FixedLenFeature([], tf.int64),
'V': tf.VarLenFeature(tf.int64)
}
return tf.parse_example(tfrec_serialized_string, feat)
def to_serialized_tfrecs(data):
serialized = []
for rec in data:
example = rec2tfrec_example(rec)
serialized.append(example.SerializeToString())
return serialized
def write_tfrecs_to_file(fname, recs):
recwriter = tf.python_io.TFRecordWriter(fname)
for rec in recs:
recwriter.write(bytes(rec))
recwriter.close()
def check_batch(data_batch, tfres):
for ky in ['A', 'uid', 'B', 'nV']:
orig_data = np.array([rec[ky] for rec in data_batch], dtype=np.int64)
assert np.all(orig_data == tfres[ky]), "batch_idx=%d ky=%s orig=%s tf=%s" % \
(batch_idx, ky, orig_data, tfres[ky])
spTensorValue = tfres['V']
tf_example_in_batch = spTensorValue.indices[:,0]
tf_V = spTensorValue.values
tf_sum_nV = np.sum(tfres['nV'])
assert tf_sum_nV == len(tf_V), "tf_sum_nV=%d != len(tf_V)=%d" % (tf_sum_nV, len(tf_V))
ex_example_in_batch = np.empty((tf_sum_nV,), np.int64)
ex_V = np.empty((tf_sum_nV,), np.int64)
idx = 0
for example_in_batch, tf_num_this_example in enumerate(tfres['nV']):
num_this_example = data_batch[example_in_batch]['nV']
assert num_this_example == tf_num_this_example
ex_example_in_batch[idx:idx+num_this_example] = example_in_batch
ex_V[idx:idx+num_this_example] = np.array(data_batch[example_in_batch]['V'])
idx += num_this_example
assert np.all(ex_example_in_batch == tf_example_in_batch), "example in batch wrong, expected=%s != tf=%s" % (ex_example_in_batch, tf_example_in_batch)
assert np.all(ex_V == tf_V), "example in batch wrong, expected=%s != tf=%s" % (ex_V, tf_V)
def check_tfrecs(sess, tfrec_output_filename, data, N, batch_size):
dataset = tf.data.TFRecordDataset(tfrec_output_filename) \
.batch(batch_size) \
.map(parse_example, num_parallel_calls=2)
tf_iter = dataset.make_initializable_iterator()
get_next = tf_iter.get_next()
sess.run(tf_iter.initializer)
num_batches = N//batch_size
nextIdx = 0
for batch_idx in range(num_batches):
data_batch = [data[idx] for idx in range(nextIdx, nextIdx + batch_size)]
nextIdx += batch_size
tfres = sess.run(get_next)
check_batch(data_batch, tfres)
def main(N=1000, batch_size=5, tfrec_output_filename='tfrec_testing.tfrecords'):
tf.reset_default_graph()
data = make_original_data(N)
tfrec_strings = to_serialized_tfrecs(data)
write_tfrecs_to_file(tfrec_output_filename, tfrec_strings)
with tf.Session() as sess:
check_tfrecs(sess, tfrec_output_filename, data, N, batch_size)
if __name__ == '__main__':
main()