我正在使用tf.data迭代大型文本语料库中的批处理。
我想仅将函数应用于数据的子集(或批处理的子集),而不是一个一元素地应用。
具体来说,我的数据迭代器会产生
query, reply
分批处理。它们都是正数对,所以我只想仅对下一批的子集(在这种情况下,仅“回复”一批”进行混洗以生成随机的负数。
例如, 输入:
query1 reply1
query2 reply2
query3 reply3
...
输出:
query1 reply1
(与输入相同)query1 replyN
(回复随机排列)当然可以使用python随机播放文本,但是我想使用tf.data使其高效,因为数据大小太大。
答案 0 :(得分:0)
假设您有queries
和replies
作为两个张量。您需要的是我认为可以与原始批处理合并的内容。
batch_size = 10
def reply_shuffle(queries, replies):
shuffled_indices = tf.random_uniform(minval=0, maxval=batch_size+1, shape=[batch_size], dtype=tf.int32)
shuffled_replies = tf.gather_nd(replies, shuffled_indices)
return queries, shuffled_replies