通过常数因子缩放张量中的行集

时间:2018-04-30 08:40:16

标签: python tensorflow

TL; DR 如何将张量的一部分缩放2(tf列表中存在的行索引)

详细说明:

indices_of_scaling_ids:存储row_id列表 Tensor("Squeeze:0", dtype=int64, device=/device:GPU:0) [1,4,5,6,12]

emb_inputs = tf.nn.embedding_lookup(embedding, self.all_rows) #tensor with shape (batch_size=4, all_row_len, emb_size=128)

因此,对于每个self.all_rows,都会评估emb_inputs

面临的问题/挑战:我需要为emb_inputs中提到的每个row_id按2.0缩放indices_of_scaling_ids。 我尝试了各种拼接方法,但似乎无法找到一个好的解决方案。有人可以建议吗?感谢

N.B。 Tensorflow初学者

1 个答案:

答案 0 :(得分:1)

尝试这样的事情:

SCALE = 2
emb_inputs = ...
indices_of_scaling_ids = ...
emb_shape = tf.shape(emb_inputs)
# Select indices in boolean array
r = tf.range(emb_shape[1])
mask = tf.reduce_any(tf.equal(r[:, tf.newaxis], indices_of_scaling_ids), axis=1)
# Tile the mask
mask = tf.tile(mask[tf.newaxis, :, tf.newaxis], (emb_shape[0], 1, emb_shape[2]))
# Choose scaled or not depending on indices
result = tf.where(mask, SCALE * emb_inputs, emb_inputs)