tensorflow:如何实现这个复杂的mask操作

时间:2018-07-23 07:56:24

标签: tensorflow

我有一个张量t1(等级2),形状为[batch_size,data_size],如下所示:

|---------------|--------------|--------------|
|      1.0      |     2.0      |     3.0      |
|---------------|--------------|--------------|
|      4.0      |     5.0      |     6.0      |
|---------------|--------------|--------------|
|      7.0      |     8.0      |     9.0      |
|---------------|--------------|--------------|

和一个形状为[batch_size]的指标张量t2(等级1):

|---------------|--------------|--------------|
|      2        |     1        |     3        |
|---------------|--------------|--------------|

我想实现一个函数(t1,t2)以输出:

|---------------|--------------|--------------|
|      1.0      |     2.0      |     0.0      |
|---------------|--------------|--------------|
|      4.0      |     0.0      |     0.0      |
|---------------|--------------|--------------|
|      7.0      |     8.0      |     9.0      |
|---------------|--------------|--------------|

它看起来像将t1与t2定义的掩码相乘,但是我不知道如何实现。

1 个答案:

答案 0 :(得分:1)

我认为您正在寻找tf.sequence_maskSee here。这基本上实现了创建您想知道的蒙版。用法如下:

mask = tf.sequence_mask(t2, dtype=tf.float32)
result = t1 * mask

如果未指定dtype,则会创建一个布尔掩码,当尝试与t1相乘时可能会导致问题,这就是我们特意要求float32的原因。

如果t2中的最大元素可以小于data_size,则应使用

mask = tf.sequence_mask(t2, maxlen=data_size, dtype=tf.float32)

以防止t1mask之间的形状不匹配。