Tensorflow 2.0:基于遮罩将更新分配给张量

时间:2019-04-22 15:42:11

标签: python tensorflow pytorch tensorflow2.0

我在使用新的Tensorflow 2.0 API(或通常的Tensorflow)重新实现Pytorch脚本时遇到一个小问题。在原始的Pytorch脚本中,张量使用布尔掩码(类似于numpy中的方式)进行更新:

# PyTorch
x_pred[empty_mask] = s_

x_pred是形状为(batch_size,81,9)的张量。该张量应在网络的前进步骤中更新几次。

s_是一个张量,包含由神经网络生成的softmax概率。形状取决于empty_mask中空白字段的数量。第一个维度是可变的,第二个维度始终是9。

empty_mask是形状为(batch_size,81)的张量。此张量适用于x_pred中的空字段。这些空字段应在每个向前的步骤中进行更新。

目前,我可以使用遮罩提取x_pred张量的相关部分。

extract = x_pred[empty_mask]

extractx_pred[empty_mask]具有相同的形状。

当我尝试以相同的样式更新Tensorflow张量时,出现以下错误消息:

# Tensorflow 2.0 
x_pred = tf.Variable(x, trainable=False)   # x is the input, so x_pred should be a copy an will be filled in the next steps

...

for i in range(max_empty_fields):

   ...
   s_ = self.softmax(previous_layers)  # Shape (???, 9)
   x_pred[empty_fields_mask] = s_      # Update x_pred

   ...


# ==> TypeError: only size-1 arrays can be converted to Python scalars

有人可以告诉我如何“就地”分配这些更新,以便直接更新x_pred吗?

非常感谢您。

更新

我找到了适合我的解决方案。首先,我必须将掩盖张量转换为x_pred的相关索引,然后才能使用新TF2 API提供的tf.tensor_scatter_nd_update()函数。

# Translate mask into indices using tf.where()
indices = tf.where(empty_fields_mask)

# Update the given indices of x_pred using tf.tensor_scatter_nd_update()
x_pred = tf.tensor_scatter_nd_update(x_pred, indices, s_)

0 个答案:

没有答案