下面是我想实现的功能的一个简短示例。如何根据另一个张量的值来访问(索引)一个张量?
import tensorflow as tf
import numpy as np
input_pl = tf.placeholder(tf.int32, shape=[None,2]))
mapping_pl = tf.placeholder(tf.int32, shape=[None,2]))
input1 = np.asarray([[1,1],
[2,2],
[3,3]])
mapping = np.asarray([[0,1],
[0,2],
[2,2]])
with tf.Graph().as_default():
output = .....
# add the 0th row of input1 with 1th row of output
# add the 0th row of input1 with 2th row of output
# add the 2th row of input1 with 2th row of output
sess = tf.Session()
output.eval(sess)
答案 0 :(得分:0)
也许可以像下面这样实现:
# Clear the global default graph before building the ops below.
tf.reset_default_graph()

input1 = tf.constant(np.array([[1, 1], [2, 2], [3, 3]]))
# Destination variable: all ones, with input1's shape and dtype.
output = tf.Variable(np.ones(input1.get_shape()), dtype=input1.dtype)
# Each row is [row index into input1, row index into output].
mapping = tf.constant(np.asarray([[0, 1], [0, 2], [2, 2]]))

sess = tf.InteractiveSession()
# NOTE(review): later TF 1.x releases deprecate this call in favour of
# tf.global_variables_initializer(); kept as-is for older installations.
sess.run(tf.initialize_all_variables())

# Split `mapping` into its two columns and flatten each to a 1-D index vector.
# A slice size of -1 takes every row, so the row count is not hard-coded.
input1_indices = tf.reshape(tf.slice(mapping, [0, 0], [-1, 1]), [-1])
output_indices = tf.reshape(tf.slice(mapping, [0, 1], [-1, 1]), [-1])
# Rows of input1 selected by the first column of `mapping`.
input1_values = tf.gather(input1, input1_indices)

# Single-argument print(...) emits the same text on Python 2 and Python 3.
print('updating output rows {}'.format(output_indices.eval()))
print('with values from input1 rows {}'.format(input1_indices.eval()))
print('output before {}'.format(sess.run(output)))
# NOTE(review): `mapping` targets output row 2 twice. tf.scatter_update documents
# the order of updates for duplicate indices as undefined, so row 2 may receive
# either value; if the rows should be accumulated instead, use tf.scatter_add.
sess.run(tf.scatter_update(output, output_indices, input1_values))
print('output after {}'.format(sess.run(output)))