在tensorflow-r1.2

时间:2017-07-18 09:41:09

标签: python tensorflow deep-learning python-3.4 tensor

给定indices形状为[batch_size, sequence_len]updates形状为[batch_size, sequence_len, sampled_size]to_shape形状为[batch_size, sequence_len, vocab_size],其中vocab_size> ;> sampled_size,我想使用tf.scatterupdates映射到to_shape的巨张量,以便to_shape[bs, indices[bs, sz]] = updates[bs, sz]。也就是说,我想逐行将updates映射到to_shape。请注意,sequence_lensampled_size是标量张量,而其他则是固定的。我试着做以下事情:

new_tensor = tf.scatter_nd(tf.expand_dims(indices, axis=2), updates, to_shape)

但是我收到了一个错误:

ValueError: The inner 2 dimension of output.shape=[?,?,?] must match the inner 1 dimension of updates.shape=[80,50,?]: Shapes must be equal rank, but are 2 and 1 for .... with input shapes: [80, 50, 1], [80, 50,?], [3]

您能告诉我如何正确使用scatter_nd吗?提前谢谢!

2 个答案:

答案 0 :(得分:3)

假设你有:

  • 形状为updates的张量[batch_size, sequence_len, sampled_size]
  • 形状为indices的张量[batch_size, sequence_len, sampled_size]

然后你做:

import tensorflow as tf

# Create updates and indices...

# Create additional indices
i1, i2 = tf.meshgrid(tf.range(batch_size),
                     tf.range(sequence_len), indexing="ij")
i1 = tf.tile(i1[:, :, tf.newaxis], [1, 1, sampled_size])
i2 = tf.tile(i2[:, :, tf.newaxis], [1, 1, sampled_size])
# Create final indices
idx = tf.stack([i1, i2, indices], axis=-1)
# Output shape
to_shape = [batch_size, sequence_len, vocab_size]
# Get scattered tensor
output = tf.scatter_nd(idx, updates, to_shape)

tf.scatter_nd需要indices张量,updates张量和一些形状。 updates是原始张量,形状只是所需的输出形状,因此[batch_size, sequence_len, vocab_size]。现在,indices更加复杂。由于您的输出具有3个维度(等级3),因此对于updates中的每个元素,您需要3个索引来确定每个元素将在输出中的位置。因此,indices参数的形状应与updates相同,并且尺寸为3的额外尺寸。在这种情况下,我们希望第一个尺寸相同,但我们仍需指定3个指数。因此我们使用tf.meshgrid生成我们需要的索引,并沿第三维度平铺它们(updates的最后一个维度中每个元素向量的第一个和第二个索引是相同的)。最后,我们使用先前创建的映射索引来堆叠这些索引,并且我们具有完整的三维索引。

答案 1 :(得分:0)

我想您可能正在寻找这个。

def permute_batched_tensor(batched_x, batched_perm_ids):
    indices = tf.tile(tf.expand_dims(batched_perm_ids, 2), [1,1,batched_x.shape[2]])

    # Create additional indices
    i1, i2 = tf.meshgrid(tf.range(batched_x.shape[0]),
                     tf.range(batched_x.shape[2]), indexing="ij")
    i1 = tf.tile(i1[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    i2 = tf.tile(i2[:, tf.newaxis, :], [1, batched_x.shape[1], 1])
    # Create final indices
    idx = tf.stack([i1, indices, i2], axis=-1)
    temp = tf.scatter_nd(idx, batched_x, batched_x.shape)
    return temp