如何在张量流中的索引处实现Numpy?

时间:2018-07-30 07:23:06

标签: numpy tensorflow

我打算在numpy中进行如下操作:

    mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    index = np.array([[1,0,0],[0,1,0],[0,0,1]])
    mat[np.where(index>0)] = 100
    print(mat)

如何在TensorFlow中实现它?

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
indi = tf.where(tf_index>0)
tf_mat[indi] = -1   <===== not allowed 

2 个答案:

答案 0 :(得分:2)

假设您要创建一个带有一些替换元素的新张量,而不更新变量,则可以执行以下操作:

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
index = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
tf_mat = tf.constant(mat)
tf_index = tf.constant(index)
tf_mat = tf.where(tf_index > 0, -tf.ones_like(tf_mat), tf_mat)
with tf.Session() as sess:
    print(sess.run(tf_mat))

输出:

[[-1  2  3]
 [ 4 -1  6]
 [ 7  8 -1]]

答案 1 :(得分:1)

您可以通过tf.where获取索引,然后可以运行索引,或者使用tf.gather从原始数组中收集数据,或者使用tf.scatter_update更新原始数据,{ {1}}用于多维更新。

tf.scatter_nd_update

请注意,更新值的第一维尺寸应与idx相同。

请参阅https://www.tensorflow.org/api_guides/python