创建随机索引为零的张量

时间:2019-07-24 11:57:37

标签: python tensorflow

现在我有一个张量random_row,我想创建一个新的张量,其形状为2-dim,其random_row为全零,所有其他行均为全一个,例如

random_row = [1, 3]  # random_row is a tensor itself

# new_tensor needs to be another tensor whose row in random_row to be all zeros
# we already know new_tensor's shape to be (4, 2)
new_tensor = [[1, 1], [0, 0], [1, 1], [0, 0]]

我该如何实现?如果有人可以提供帮助,我将不胜感激!

1 个答案:

答案 0 :(得分:0)

您可以使用他的功能来做到这一点:

import tensorflow as tf

def zero_rows(x, idx)
    # Turn row indices into a boolean mask
    n = tf.shape(x)[0]
    m = tf.scatter_nd(tf.expand_dims(idx, 1), tf.ones_like(idx, dtype=tf.bool), [n])
    # Select zeros where the indices are or the data elsewhere
    return tf.where(m, tf.zeros_like(x), x)

# Test
with tf.Graph().as_default(), tf.Session() as sess:
    y = zero_rows(x, idx)
    print(sess.run(y, feed_dict={x: [[1, 2], [3, 4], [5, 6], [7, 8]], idx: [1, 3]}))
    # [[1 2]
    #  [0 0]
    #  [5 6]
    #  [0 0]]