在张量流中,沿着0轴从张量中随机(子)采样k个条目

时间:2018-06-04 04:21:50

标签: python tensorflow

给定rank>=1 T的张量,我想从0轴上统一地随机抽样k个条目。

编辑:采样应作为延迟操作的计算图的一部分,并且每次调用时都应输出不同的随机条目。

例如,给定T等级2

T = tf.constant( \
     [[1,1,1,1,1],
      [2,2,2,2,2],
      [3,3,3,3,3],
      ....
      [99,99,99,99,99],
      [100,100,100,100,100]] \
     )

使用k=3,可能的输出为:

#output = \
#   [[34,34,34,34,34],
#    [6,6,6,6,6],
#    [72,72,72,72,72]]

如何在tensorflow中实现这一目标?

1 个答案:

答案 0 :(得分:2)

您可以在索引数组中使用随机shuffle:

获取第一个sample_num索引,并使用它们来选择输入的切片。

idxs = tf.range(tf.shape(input)[0])
ridxs = tf.random_shuffle(idx)[:sample_num]
rinput = tf.gather(input, ridxs)