如何从Tensorflow中的多个对齐数据集中随机选择数据?

时间:2019-02-07 05:35:27

标签: python tensorflow-datasets

假设我们有2个组的6个文本文件,每个组由3个文件组成,例如,

  • 第1组:1.a,1.b,1.c
  • 第2组:2.a,2.b,2.c

给定了固定的阈值randrandom()模块的random,我希望得到3个张量:

  • x组:x_a,x_b,x_c

每个文件中的行数相同且对齐,并且x_a的第n行为:

  • 第1步:'<nth line from 1.a>' if rand < random() else '<nth line from 2.a>'

x_b和x_c的第n行也将是:

  • 第2步:<'nth line from 1.b>' if '<nth row of x_a from 1.a>' else '<nth line from 2.b>'
  • 步骤3:<'nth line from 1.c>' if '<nth row of x_a from 1.a>' else '<nth line from 2.c>'(遵循步骤2,但适用于x_c)

以便x_a,x_b和x_c都对齐。

我使用的工具是tf.data.TextLineDataset,能否告诉我如何进行随机选择并保持选择轨道?谢谢!

1 个答案:

答案 0 :(得分:0)

========================我的解决方案====================== ===

我提供了一个跟踪文件来指导这3个文件。仍然欢迎其他解决方案!

a1 = tf.data.TextLineDataset(afile1).map(...)
b1 = tf.data.TextLineDataset(bfile1).map(...)
c1 = tf.data.TextLineDataset(cfile1).map(...)
...
index = tf.data.TextLineDataset(track_file).map(lambda line: tf.string_to_number(line, tf.int32))
As = tf.data.Dataset.zip((index, a1, a2))
Bs = tf.data.Dataset.zip((index, b1, b2))
...
ax = As.map(lambda i, l, r: tf.where(i > 0, l, r))
bx = As.map(lambda i, l, r: tf.where(i > 0, l, r))
cx = As.map(lambda i, l, r: tf.where(i > 0, l, r))
...