Tensorflow:如何操纵张量的subtensors?

时间:2017-10-28 21:04:43

标签: tensorflow

我想使用另一个张量作为索引来操纵张量的子指标(如列或行)。 所以给了我三个张量:

tensor = tf.constant([[1,2,3], [4,5,6]])
r = tf.constant(0)
new_row = tf.constant([-3,-2,-1])

我需要一些功能或适用于这三个张量的东西,给我

new_tensor = tf.constant([[-3,-2,-1],[4,5,6]])

所以我想用'new_row'替换张量'张量'的第r行。这甚至可能吗?

更新

好的,所以我发现以下解决方案可以动态替换矩阵中的列,也就是说,我们既不知道矩阵的尺寸,也不知道要替换的列的索引,也不知道图形构造期间的实际替换列时间。

import tensorflow as tf


# matrix: 2D-tensor of shape (m,n)
# new_column: 1D-tensor of shape m
# r: 0D-tensor with value from { 0,...,n-1 }
# Outputs 2D-tensor of shape (m,n) with the same values as matrix, except that the r-th column has been replaced by new_column
def replace_column(matrix, new_column, r):
    num_rows,num_cols = tf.unstack(tf.shape(matrix))
    index_row = tf.stack( [ tf.eye(num_cols,dtype=tf.float64)[r,:] ] )
    old_column = matrix[:,r]
    new = tf.matmul( tf.stack([new_column],axis=1), index_row )
    old = tf.matmul( tf.stack([old_column],axis=1), index_row )
    return (matrix-old)+new


matrix = [[1,2,3],[4,5,6],[7,8,9]]
column = [-1,-2,-3]
pos = 1

dynamic = tf.placeholder(tf.float64, shape=[None,None])
pos_tensor = tf.placeholder(tf.int32,shape=[])
column_tensor = tf.placeholder(dtype=tf.float64,shape=[None])

result_dynamic = replace_column(dynamic, column_tensor, pos_tensor)

with tf.Session() as sess:
    print "Input matrix, column, position: ", matrix, column, pos
    print "Dynamic result: ", sess.run([result_dynamic], { dynamic: matrix, pos_tensor: pos, column_tensor: column })

它使用外部产品操作来完成这项工作,这也是我无法将其推广到一般张量的原因(也因为我只需要它用于矩阵;-))。

1 个答案:

答案 0 :(得分:0)

import tensorflow as tf

sess = tf.InteractiveSession()

tensor = tf.constant([[1,2,3], [4,5,6]])
r = tf.constant(0)
new_row = tf.constant([-3,-2,-1])

shp1 = tensor.get_shape()

unpacked_tensor = tf.unstack(tensor, axis=0)
new_tensor_list = []
for iiR in list(range(shp1[0])):
    new_tensor_list.append(tf.where(tf.equal(r, iiR), new_row, unpacked_tensor[iiR]))

new_tensor = tf.stack(new_tensor_list, axis = 0)

print(new_tensor.eval())

输出:

[[-3 -2 -1]
 [ 4  5  6]]