在列表中指定的索引处为张量分配零

时间:2018-11-04 10:11:26

标签: python tensorflow

例如,我必须张量

A = tf.Tensor(
       [[1.0986123 0.6931472 0.        0.6931472 0.       ]
        [0.        0.        0.        0.        0.       ]
        [3.7376697 3.7612002 3.7841897 3.8066626 3.8286414]], shape=(3, 5), dtype=float32)

B = tf.Tensor(
   [[2 1]
    [2 2]], shape=(2, 2), dtype=int64)

张量B保存张量A中的索引。我想将张量A中的每个值更新为索引列表B中列出的零。

因此,预期结果将是

tf.Tensor(
       [[1.0986123 0.6931472 0.        0.6931472 0.       ]
        [0.        0.        0.        0.        0.       ]
        [3.7376697 0 0 3.8066626 3.8286414]], shape=(3, 5), dtype=float32)

因此索引[2,1]和[2,2]的条目设置为0。

我查看了tf.assign,但它们只能用于tf.Variabletf.boolean_mask将是一个很好的方法,但是我不知道,也找不到我如何创建带有索引列表的布尔掩码。

我看了我能找到的张量流函数以及相关的S / O答案,但是找不到令人满意的解决方案。

1 个答案:

答案 0 :(得分:2)

您可以为此使用tf.scatter_nd_update。例如:

A = tf.Variable(
    [[1.0986123, 0.6931472, 0.       , 0.6931472, 0.       ],
     [0.       , 0.       , 0.       , 0.       , 0.       ],
     [3.7376697, 3.7612002, 3.7841897, 3.8066626, 3.8286414]], dtype=tf.float32)

B = tf.Variable(
    [[2, 1],
     [2, 2]], dtype=tf.int64)

C = tf.scatter_nd_update(A, B, tf.zeros(shape=tf.shape(B)[0]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(C))

A = tf.constant(
    [[1.0986123, 0.6931472, 0.       , 0.6931472, 0.       ],
     [0.       , 0.       , 0.       , 0.       , 0.       ],
     [3.7376697, 3.7612002, 3.7841897, 3.8066626, 3.8286414]], dtype=tf.float32)

B = tf.constant(
    [[2, 1],
     [2, 2]], dtype=tf.int64)

AV = tf.Variable(A)

C = tf.scatter_nd_update(AV, B, tf.zeros(shape=tf.shape(B)[0]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(C))