tensorflow索引如何工作

时间:2016-05-05 23:18:44

标签: tensorflow

我无法理解张量流的基本概念。索引如何用于张量读/写操作?为了使这一点具体化,如何将以下numpy示例转换为tensorflow(使用张量来分配数组,索引和值):

x = np.zeros((3, 4))
row_indices = np.array([1, 1, 2])
col_indices = np.array([0, 2, 3])
x[row_indices, col_indices] = 2
x

带输出:

array([[ 0.,  0.,  0.,  0.],
       [ 2.,  0.,  2.,  0.],
       [ 0.,  0.,  0.,  2.]])

......和......

x[row_indices, col_indices] = np.array([5, 4, 3])
x

带输出:

array([[ 0.,  0.,  0.,  0.],
       [ 5.,  0.,  4.,  0.],
       [ 0.,  0.,  0.,  3.]])

......最后......

y = x[row_indices, col_indices]
y

带输出:

array([ 5.,  4.,  3.])

1 个答案:

答案 0 :(得分:10)

github问题#206可以很好地支持这一点,同时你必须采用冗长的解决方法

第一个例子可以通过tf.select来完成,tf.reset_default_graph() row_indices = tf.constant([1, 1, 2]) col_indices = tf.constant([0, 2, 3]) x = tf.zeros((3, 4)) sess = tf.InteractiveSession() # get list of ((row1, col1), (row2, col2), ..) coords = tf.transpose(tf.pack([row_indices, col_indices])) # get tensor with 1's at positions (row1, col1),... binary_mask = tf.sparse_to_dense(coords, x.get_shape(), 1) # convert 1/0 to True/False binary_mask = tf.cast(binary_mask, tf.bool) twos = 2*tf.ones(x.get_shape()) # make new x out of old values or 2, depending on mask x = tf.select(binary_mask, twos, x) print x.eval() 通过从一个或另一个中选择每个元素来组合两个相同形状的张量

[[ 0.  0.  0.  0.]
 [ 2.  0.  2.  0.]
 [ 0.  0.  0.  2.]]

给出

scatter_update

第二个可以使用scatter_update完成,除了dynamic_stitch仅支持线性索引并处理变量。所以你可以创建一个临时变量并像这样使用重塑。 (为了避免使用# get linear indices linear_indices = row_indices*x.get_shape()[1]+col_indices # turn 'x' into 1d variable since "scatter_update" supports linear indexing only x_flat = tf.Variable(tf.reshape(x, [-1])) # no automatic promotion, so make updates float32 to match x updates = tf.constant([5, 4, 3], dtype=tf.float32) sess.run(tf.initialize_all_variables()) sess.run(tf.scatter_update(x_flat, linear_indices, updates)) # convert back into original shape x = tf.reshape(x_flat, x.get_shape()) print x.eval() 的变量,请参阅结尾)

[[ 0.  0.  0.  0.]
 [ 5.  0.  4.  0.]
 [ 0.  0.  0.  3.]]

给出

gather_nd

最后,print tf.gather_nd(x, coords).eval() 已支持第三个示例,您可以编写

[ 5.  4.  3.]

获得

x[cols,rows]=newvals

编辑,5月6日

更新select可以在不使用变量(在会话运行调用之间占用内存)的情况下完成,方法是使用sparse_to_densedynamic_stitch来获取稀疏值的向量,或者依赖于{{ 1}}

sess = tf.InteractiveSession()
x = tf.zeros((3, 4))
row_indices = tf.constant([1, 1, 2])
col_indices = tf.constant([0, 2, 3])

# no automatic promotion, so specify float type
replacement_vals = tf.constant([5, 4, 3], dtype=tf.float32)

# convert to linear indexing in row-major form
linear_indices = row_indices*x.get_shape()[1]+col_indices
x_flat = tf.reshape(x, [-1])

# use dynamic stitch, it merges the array by taking value either
# from array1[index1] or array2[index2], if indices conflict,
# the later one is used 
unchanged_indices = tf.range(tf.size(x_flat))
changed_indices = linear_indices
x_flat = tf.dynamic_stitch([unchanged_indices, changed_indices],
                           [x_flat, replacement_vals])
x = tf.reshape(x_flat, x.get_shape())
print x.eval()