How can I perform operations on a tensor using indices?

Date: 2016-05-02 03:57:40

Tags: tensorflow

Here is a short example of what I want to do. How can I access one tensor based on the values of another tensor?

import tensorflow as tf
import numpy as np


input1 = np.asarray([[1, 1],
                     [2, 2],
                     [3, 3]])

mapping = np.asarray([[0, 1],
                      [0, 2],
                      [2, 2]])

with tf.Graph().as_default():
    input_pl = tf.placeholder(tf.int32, shape=[None, 2])
    mapping_pl = tf.placeholder(tf.int32, shape=[None, 2])

    output = ...  # what goes here?

    # add row 0 of input1 to row 1 of output
    # add row 0 of input1 to row 2 of output
    # add row 2 of input1 to row 2 of output

    sess = tf.Session()
    print(output.eval(session=sess,
                      feed_dict={input_pl: input1, mapping_pl: mapping}))
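For reference, the computation the comments describe can be written in plain NumPy. This sketch assumes that `output` starts as a matrix of ones (the same starting value the answer uses) and that each `[i, j]` row of `mapping` means "add row `i` of `input1` to row `j` of `output`". It uses `np.add.at` because target row 2 appears twice and both contributions should be accumulated. The helper name `apply_mapping` is hypothetical.

```python
import numpy as np


def apply_mapping(input1, mapping, output):
    """Add input1[i] to output[j] for every [i, j] pair in mapping.

    Hypothetical helper: a NumPy reference for the intended semantics.
    """
    result = output.copy()
    source_rows = mapping[:, 0]  # rows to read from input1 (a "gather")
    target_rows = mapping[:, 1]  # rows to update in result
    # np.add.at is unbuffered, so repeated target rows accumulate every update
    np.add.at(result, target_rows, input1[source_rows])
    return result


input1 = np.asarray([[1, 1], [2, 2], [3, 3]])
mapping = np.asarray([[0, 1], [0, 2], [2, 2]])
output = np.ones_like(input1)  # assumed starting value

print(apply_mapping(input1, mapping, output))
# [[1 1]
#  [2 2]
#  [5 5]]
```

Row 0 is never a target, so it keeps its initial ones. Row 2 receives both `input1[0]` and `input1[2]`.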

1 answer:

Answer 0 (score: 0)

Maybe something like this:

 
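import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
input1 = tf.constant(np.array([[1, 1], [2, 2], [3, 3]]))

# start output as ones with the same shape and dtype as input1
output = tf.Variable(tf.ones_like(input1))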
mapping = tf.constant(np.asarray([[0,1],[0,2],[2,2]]))

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

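# column 0 of mapping = source rows in input1, column 1 = target rows in output
input1_indices = tf.reshape(tf.slice(mapping, [0, 0], [3, 1]), [-1])
output_indices = tf.reshape(tf.slice(mapping, [0, 1], [3, 1]), [-1])
# pick the source rows out of input1
input1_values = tf.gather(input1, input1_indices)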
print('updating output rows', output_indices.eval())
print('with values from input1 rows', input1_indices.eval())
print('output before', sess.run(output))
# overwrite the selected rows of output with the gathered rows
sess.run(tf.scatter_update(output, output_indices, input1_values))
print('output after', sess.run(output))
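Note that `tf.scatter_update` overwrites the target rows rather than adding to them, and because `mapping` sends two source rows to output row 2, the TensorFlow documentation describes the order of such duplicate updates as undefined. The NumPy sketch below (an illustration, not TensorFlow code) contrasts three update styles on the same data: plain assignment, buffered `+=` with fancy indexing (which drops one of the duplicate additions), and `np.add.at`, which accumulates. In TensorFlow the accumulating counterpart is `tf.scatter_add` (1.x), or `tf.tensor_scatter_nd_add` for plain tensors in 2.x.

```python
import numpy as np

input1 = np.asarray([[1, 1], [2, 2], [3, 3]])
mapping = np.asarray([[0, 1], [0, 2], [2, 2]])
source_rows, target_rows = mapping[:, 0], mapping[:, 1]
values = input1[source_rows]  # gathered rows: [[1, 1], [1, 1], [3, 3]]

# 1) Overwrite (like scatter_update): row 2 ends up with only one of its two sources
overwritten = np.ones_like(input1)
overwritten[target_rows] = values

# 2) Buffered "+=": row 2 receives only one of its two additions
buffered = np.ones_like(input1)
buffered[target_rows] += values

# 3) Unbuffered accumulation (like scatter_add): both additions reach row 2
accumulated = np.ones_like(input1)
np.add.at(accumulated, target_rows, values)

print(overwritten.tolist())
print(buffered.tolist())     # row 2 is [2, 2] or [4, 4], never [5, 5]
print(accumulated.tolist())  # [[1, 1], [2, 2], [5, 5]]
```

Only the third variant matches the "add" semantics described in the question.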