TL; DR 如何将张量的一部分缩放2(tf列表中存在的行索引)
详细说明:
indices_of_scaling_ids
:存储row_id列表
Tensor("Squeeze:0", dtype=int64, device=/device:GPU:0)
[1,4,5,6,12]
emb_inputs = tf.nn.embedding_lookup(embedding, self.all_rows)
#tensor with shape (batch_size=4, all_row_len, emb_size=128)
因此,对于每个self.all_rows
,都会评估emb_inputs
。
面临的问题/挑战:我需要为emb_inputs
中提到的每个row_id按2.0
缩放indices_of_scaling_ids
。
我尝试了各种拼接方法,但似乎无法找到一个好的解决方案。有人可以建议吗?感谢
N.B。 Tensorflow初学者
答案 0 :(得分:1)
尝试这样的事情:
SCALE = 2
emb_inputs = ...
indices_of_scaling_ids = ...
emb_shape = tf.shape(emb_inputs)
# Select indices in boolean array
r = tf.range(emb_shape[1])
mask = tf.reduce_any(tf.equal(r[:, tf.newaxis], indices_of_scaling_ids), axis=1)
# Tile the mask
mask = tf.tile(mask[tf.newaxis, :, tf.newaxis], (emb_shape[0], 1, emb_shape[2]))
# Choose scaled or not depending on indices
result = tf.where(mask, SCALE * emb_inputs, emb_inputs)