给定indices
形状为[batch_size, sequence_len]
,updates
形状为[batch_size, sequence_len, sampled_size]
,to_shape
形状为[batch_size, sequence_len, vocab_size]
,其中vocab_size
> ;> sampled_size
,我想使用tf.scatter
将updates
映射到to_shape
的巨张量,以便to_shape[bs, indices[bs, sz]] = updates[bs, sz]
。也就是说,我想逐行将updates
映射到to_shape
。请注意,sequence_len
和sampled_size
是标量张量,而其他则是固定的。我试着做以下事情:
new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)
但是我收到了一个错误:
ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]
您能告诉我如何正确使用scatter_nd
吗?提前谢谢!
答案 0 :(得分:3)
假设你有:
updates
的张量[batch_size, sequence_len, sampled_size]
。indices
的张量[batch_size, sequence_len, sampled_size]
。然后你做:
import tensorflow as tf
# Create updates and indices...
# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)
tf.scatter_nd
需要indices
张量,updates
张量和一些形状。 updates
是原始张量,形状只是所需的输出形状,因此[batch_size, sequence_len, vocab_size]
。现在,indices
更加复杂。由于您的输出具有3个维度(等级3),因此对于updates
中的每个元素,您需要3个索引来确定每个元素将在输出中的位置。因此,indices
参数的形状应与updates
相同,并且尺寸为3的额外尺寸。在这种情况下,我们希望第一个尺寸相同,但我们仍需指定3个指数。因此我们使用tf.meshgrid
生成我们需要的索引,并沿第三维度平铺它们(updates
的最后一个维度中每个元素向量的第一个和第二个索引是相同的)。最后,我们使用先前创建的映射索引来堆叠这些索引,并且我们具有完整的三维索引。
答案 1 :(得分:0)
我想您可能正在寻找这个。
def permute_batched_tensor(batched_x, batched_perm_ids):
indices = tf.tile(tf.expand_dims(batched_perm_ids, 2), [1,1,batched_x.shape[2]])
# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batched_x.shape[0]),
tf.range(batched_x.shape[2]), indexing="ij")
i1 = tf.tile(i1[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
i2 = tf.tile(i2[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
# Create final indices
idx = tf.stack([i1, indices, i2], axis=-1)
temp = tf.scatter_nd(idx, batched_x, batched_x.shape)
return temp