我在使用新的Tensorflow 2.0 API(或通常的Tensorflow)重新实现Pytorch脚本时遇到一个小问题。在原始的Pytorch脚本中,张量使用布尔掩码(类似于numpy中的方式)进行更新:
# PyTorch
x_pred[empty_mask] = s_
x_pred
是形状为(batch_size,81,9)的张量。该张量应在网络的前进步骤中更新几次。
s_
是一个张量,包含由神经网络生成的softmax概率。形状取决于empty_mask
中空白字段的数量。第一个维度是可变的,第二个维度始终是9。
empty_mask
是形状为(batch_size,81)的张量。此张量适用于x_pred
中的空字段。这些空字段应在每个向前的步骤中进行更新。
目前,我可以使用遮罩提取x_pred
张量的相关部分。
extract = x_pred[empty_mask]
extract
和x_pred[empty_mask]
具有相同的形状。
当我尝试以相同的样式更新Tensorflow张量时,出现以下错误消息:
# Tensorflow 2.0
x_pred = tf.Variable(x, trainable=False) # x is the input, so x_pred should be a copy an will be filled in the next steps
...
for i in range(max_empty_fields):
...
s_ = self.softmax(previous_layers) # Shape (???, 9)
x_pred[empty_fields_mask] = s_ # Update x_pred
...
# ==> TypeError: only size-1 arrays can be converted to Python scalars
有人可以告诉我如何“就地”分配这些更新,以便直接更新x_pred
吗?
非常感谢您。
更新
我找到了适合我的解决方案。首先,我必须将掩盖张量转换为x_pred
的相关索引,然后才能使用新TF2 API提供的tf.tensor_scatter_nd_update()函数。
# Translate mask into indices using tf.where()
indices = tf.where(empty_fields_mask)
# Update the given indices of x_pred using tf.tensor_scatter_nd_update()
x_pred = tf.tensor_scatter_nd_update(x_pred, indices, s_)