我有两个张量(A,B),A的形状为(N,N,12),B的形状为(N,N)。
我想保留一些值并根据B.B删除A中的其他值,就像字典一样。
例如:
B[1][1] = 2
newArray[1][1] = A[1][1][2*4:2*4+4]
B[i][j] = n # n in [0:2]
newArray[i][j] = A[i][j][n*4:n*4+4]
assert(newArray.shape == (N,N,4))
如何在Tensorflow中编写代码?
非常感谢。
答案 0 :(得分:0)
我认为这可以满足您的要求
import tensorflow as tf
# Input data
A = tf.placeholder(tf.float32, (None, None, 12))
B = tf.placeholder(tf.int32, (None, None))
# Reshape 12-vectors into 3x4 inner matrices
A_shape = tf.shape(A)
rows, cols = A_shape[0], A_shape[1]
A_res = tf.reshape(A, (rows, cols, 3, 4))
# Make indices
ii, jj = tf.meshgrid(tf.range(rows), tf.range(cols), indexing='ij')
B_idx = tf.stack([ii, jj, B], axis=-1)
# Gather result
result = tf.gather_nd(A_res, B_idx)
# Test
with tf.Session() as sess:
A_val = [
[
[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,],
[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,],
],
[
[24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,],
[36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,],
],
]
B_val = [
[2, 0],
[1, 2],
]
result_val = sess.run(result, feed_dict={ A: A_val, B: B_val })
print(result_val)
输出:
[[[ 8. 9. 10. 11.]
[12. 13. 14. 15.]]
[[28. 29. 30. 31.]
[44. 45. 46. 47.]]]