假设我有200个tfrecord文件,每个tfrecord文件中有60个数字(在我们的情况下为整数)。每次,我想从16个tfrecord文件中获得8个数字(随机选择)。 为此,我使用
all_file_name = 'tfrecord-*'
tf.train.match_filenames_once(all_file_name)
tf.random_shuffle(all_tfrecord_files)[0:16]
获取16个tfrecord文件。
这里
for i in range(k):
data_loaded = decode_from_tfrecords([file_k[i]], data_shape, n, min_after_dequeue, num_threads)
I.append(data_loaded)
我从16条tfrecord中的每条中生成8个数字。
但是,我发现输出num并非来自16 tfrecord。
然后,我删除tf.random.shuffe()
(即使用all_tfrecord_files[0:16]
)
这次,输出数字与tfrecord
值匹配。
详细信息可以在see here
中找到您能告诉我问题吗,谢谢!