考虑一下,我有一个名为x
的张量,其形状为[1, batch_size]
。如果my_tensor
中受尊重的值小于或等于零,我想以[batch_size, seq_length]
的形式更改另一个名为x
的张量的行。
我想我可以通过代表一个代码来更好地解释:
import tensorflow as tf
batch_size = 3
seq_length = 5
x = tf.constant([-1, 4, 0]) # size is [1, batch_size]
# select the indices of the rows to be changed
candidate_rows = tf.where(tf.less_equal(x, 0))
my_tensor = tf.random.uniform(shape=(batch_size, seq_length), minval=10, maxval=30, seed=123)
sess = tf.InteractiveSession()
print(sess.run(candidate_rows))
print(sess.run(my_tensor))
将产生:
candidate_rows =
[[0]
[2]]
my_tensor =
[[10.816193 14.168425 11.83606 24.044014 24.146267]
[17.929298 11.330187 15.837727 10.592653 29.098463]
[10.122135 16.338099 24.35467 15.236387 10.991222]]
,我想将张量中的行[0]和[2]更改为另一个值,例如都等于1。
[[1 1 1 1]
[17.929298 11.330187 15.837727 10.592653 29.098463]
[1 1 1 1 1]]
当我使用tf.where
时,也许所有问题都会出现。感谢您的帮助:)
答案 0 :(得分:1)
解决这个问题的一种方法是使用tf.where
在两个张量的元素之间进行选择。
t = tf.ones(shape=my_tensor.shape, dtype=my_tensor.dtype)
my_tensor = tf.where(x > 0, my_tensor, t)