如何基于预定义的行索引逐行更改张量的值

时间:2019-06-04 20:28:37

标签: python tensorflow

考虑一下,我有一个名为x的张量,其形状为[1, batch_size]。如果my_tensor中受尊重的值小于或等于零,我想以[batch_size, seq_length]的形式更改另一个名为x的张量的行。

我想我可以通过代表一个代码来更好地解释:

import tensorflow as tf

batch_size = 3
seq_length = 5

x = tf.constant([-1, 4, 0]) # size is [1, batch_size]

# select the indices of the rows to be changed
candidate_rows = tf.where(tf.less_equal(x, 0))

my_tensor = tf.random.uniform(shape=(batch_size, seq_length), minval=10, maxval=30, seed=123)

sess = tf.InteractiveSession()
print(sess.run(candidate_rows))
print(sess.run(my_tensor))

将产生:

candidate_rows = 
 [[0]
  [2]]

my_tensor = 
 [[10.816193 14.168425 11.83606  24.044014 24.146267]
 [17.929298 11.330187 15.837727 10.592653 29.098463]
 [10.122135 16.338099 24.35467  15.236387 10.991222]]

,我想将张量中的行[0]和[2]更改为另一个值,例如都等于1。

 [[1 1 1 1]
 [17.929298 11.330187 15.837727 10.592653 29.098463]
 [1 1 1 1 1]]

当我使用tf.where时,也许所有问题都会出现。感谢您的帮助:)

1 个答案:

答案 0 :(得分:1)

解决这个问题的一种方法是使用tf.where在两个张量的元素之间进行选择。

t = tf.ones(shape=my_tensor.shape, dtype=my_tensor.dtype)
my_tensor = tf.where(x > 0, my_tensor, t)