tensorflow:沿第二维切割张量

时间:2017-07-29 03:22:02

标签: python tensorflow

我有张量 X ,其形状为(无,56,300,1),另一张张 y ,其形状为< strong>(无,15),这些张量的第一个维度是batch_size,我想用y作为索引得到张量z,z的形状是(None,15,300,1) 即可。有没有可行的方法来做到这一点?

我写了一个简单的代码来测试,因为我发现这对我来说很难,因为在实践中我不知道batch_size(这些张量的第一维是),

这是我的测试代码:

import numpy as np
import tensorflow as tf

# In this test code , batch_size is 4.
# params' shape is (4, 3, 2 ,1), in practice is (None, 56, 300, 1), 
params = [
            [[['a0'], ['b0']], [['d0'], ['e0']], [['f0'], ['g0']]], 
            [[['a1'], ['b1']], [['d1'], ['e1']], [['f1'], ['g1']]],
            [[['a2'], ['b2']], [['d2'], ['e2']], [['f2'], ['g2']]],
            [[['a3'], ['b3']], [['d3'], ['e3']], [['f3'], ['g3']]],
        ]

# ind's shape is (4, 2) (In practice is (None, 15)), 
# so I wanna get output whose's shape is (4, 2, 2, 1), (In practice is (None, 15, 300, 1))
ind = [[1, 0], [0, 2], [2, 0], [2, 1]]
#ouput = [
# [[['d0'], ['e0']], [['a0'], ['b0']]],
# [[['a1'], ['b1']], [['f1'], ['g1']]],
# [[['f2'], ['g2']], [['a2'], ['b2']]],
# [[['f3'], ['g3']], [['d3'], ['e3']]]
#]

with tf.variable_scope('gather') as scope:
    tf_par = tf.constant(params)
    tf_ind = tf.constant(ind)
    res = tf.gather_nd(tf_par, tf_ind)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    print sess.run(res)
    print res

2 个答案:

答案 0 :(得分:3)

使用x沿第二维切片ind,即切片

  • 形状x的张量(d0, d1, d2,...)d0可能是None
  • 具有形状ind的索引(d0, n1)的张量,
  • 获取形状y的张量(d0, n1, d2, ...)

您可以使用tf.gather_ndtf.shape在运行时获取形状:

ind_shape = tf.shape(ind)
ndind = tf.stack([tf.tile(tf.range(ind_shape[0])[:, None], [1, ind_shape[1]]),
                  ind], axis=-1)
y = tf.gather_nd(x, ndind)

答案 1 :(得分:0)

对于您认为的结果,您应该使用:

ind = [[0, 1], [0, 0], [1, 0], [1, 2], [2, 2], [2, 0], [3, 2], [3, 1]]

<强>更新

您可以使用此代码获取所需内容,并使用当前输入:

with tf.variable_scope('gather') as scope:
    tf_par = tf.constant(params)
    tf_ind = tf.constant(ind)

    tf_par_shape = tf.shape(tf_par)
    tf_ind_shape = tf.shape(tf_ind)
    tf_r = tf.div(tf.range(0, tf_ind_shape[0] * tf_ind_shape[1]), tf_ind_shape[1])
    tf_r = tf.expand_dims(tf_r, 1)
    tf_ind = tf.expand_dims(tf.reshape(tf_ind, shape = [-1]), 1)
    tf_ind = tf.concat([tf_r, tf_ind], axis=1)

    res = tf.gather_nd(tf_par, tf_ind)
    res = tf.reshape(res, shape = (-1, tf_ind_shape[1], tf_par_shape[2], tf_par_shape[3]))