Tensorflow数据:将功能应用于批处理

时间:2018-10-10 15:20:51

标签: python tensorflow tensorflow-datasets

我正在使用tf.data迭代大型文本语料库中的批处理。

我想仅将函数应用于数据的子集(或批处理的子集),而不是一个一元素地应用。 具体来说,我的数据迭代器会产生 query, reply分批处理。它们都是正数对,所以我只想仅对下一批的子集(在这种情况下,仅“回复”一批”进行混洗以生成随机的负数。

例如, 输入:

query1 reply1

query2 reply2

query3 reply3

...

输出:

  • 正对:query1 reply1(与输入相同)
  • 负对:query1 replyN(回复随机排列)

当然可以使用python随机播放文本,但是我想使用tf.data使其高效,因为数据大小太大。

1 个答案:

答案 0 :(得分:0)

假设您有queriesreplies作为两个张量。您需要的是我认为可以与原始批处理合并的内容。

batch_size = 10
def reply_shuffle(queries, replies):
   shuffled_indices = tf.random_uniform(minval=0, maxval=batch_size+1, shape=[batch_size], dtype=tf.int32)
   shuffled_replies = tf.gather_nd(replies, shuffled_indices) 
   return queries, shuffled_replies