我如何在Tensorflow中获得一些数组的值?

时间:2018-10-28 09:09:18

标签: python tensorflow

我有两个张量(A,B),A的形状为(N,N,12),B的形状为(N,N)。
我想保留一些值并根据B.B删除A中的其他值,就像字典一样。
例如:

B[1][1] = 2
newArray[1][1] = A[1][1][2*4:2*4+4]
B[i][j] = n  # n in [0:2]
newArray[i][j] = A[i][j][n*4:n*4+4]
assert(newArray.shape == (N,N,4))

如何在Tensorflow中编写代码?
非常感谢。

1 个答案:

答案 0 :(得分:0)

我认为这可以满足您的要求

import tensorflow as tf

# Input data
A = tf.placeholder(tf.float32, (None, None, 12))
B = tf.placeholder(tf.int32, (None, None))
# Reshape 12-vectors into 3x4 inner matrices
A_shape = tf.shape(A)
rows, cols = A_shape[0], A_shape[1]
A_res = tf.reshape(A, (rows, cols, 3, 4))
# Make indices
ii, jj = tf.meshgrid(tf.range(rows), tf.range(cols), indexing='ij')
B_idx = tf.stack([ii, jj, B], axis=-1)
# Gather result
result = tf.gather_nd(A_res, B_idx)
# Test
with tf.Session() as sess:
    A_val = [
        [
            [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,],
            [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,],
        ],
        [
            [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,],
            [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,],
        ],
    ]
    B_val = [
        [2, 0],
        [1, 2],
    ]
    result_val = sess.run(result, feed_dict={ A: A_val, B: B_val })
    print(result_val)

输出:

[[[ 8.  9. 10. 11.]
  [12. 13. 14. 15.]]

 [[28. 29. 30. 31.]
  [44. 45. 46. 47.]]]