使用整数张量对张量流中的张量进行索引

时间:2017-12-28 23:43:57

标签: python numpy tensorflow

我的问题类似于here,但不完全相同。我有两个张量

mu: (shape=(1000,1), dtype=np.float32)
p : (shape=(100,30), dtype=np.int64)

我想要的是创建一个新的张量

x : (shape=(100,30), dtype=np.float32)

这样

x[i,j] = mu[p[i,j]]

这可以使用高级索引

在numpy中完成
x = mu[p]

我尝试使用tf.gather_nd(mu, p)命令,但在我的情况下,我收到以下错误

*** ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [100,30] and params shape: [1000] for 'GatherNd_2' (op: 'GatherNd') with input shapes: [1000], [100,30].

因此,为了使用它,我必须建立一个新的坐标张量。有没有更简单的方法来实现我想要的目标?

2 个答案:

答案 0 :(得分:2)

这是一个有效的解决方案:

tf.reshape(tf.gather(mu[:,0], tf.reshape(p, (-1,))), p.shape)

基本上它

  1. 将索引数组展平为1d,tf.reshape(p, (-1,));
  2. mu[:,0]收集元素(mu的第一列);
  3. 然后将其重塑为p的形状。
  4. 最小示例

    import tensorflow as tf
    tf.InteractiveSession()
    
    mu = tf.reshape(tf.multiply(tf.cast(tf.range(10), tf.float32), 0.1), (10, 1))
    mu.eval()
    #array([[ 0.        ],
    #       [ 0.1       ],
    #       [ 0.2       ],
    #       [ 0.30000001],
    #       [ 0.40000001],
    #       [ 0.5       ],
    #       [ 0.60000002],
    #       [ 0.69999999],
    #       [ 0.80000001],
    #       [ 0.90000004]], dtype=float32)
    
    p = tf.constant([[1,3],[2,4],[3,1]], dtype=tf.int64)
    
    tf.reshape(tf.gather(mu[:,0], tf.reshape(p, (-1,))), p.shape).eval()
    
    #array([[ 0.1       ,  0.30000001],
    #       [ 0.2       ,  0.40000001],
    #       [ 0.30000001,  0.1       ]], dtype=float32)
    

    使用gather_nd而不重塑的另外两个选项:

    tf.gather_nd(mu[:,0], tf.expand_dims(p, axis=-1)).eval()
    
    #array([[ 0.1       ,  0.30000001],
    #       [ 0.2       ,  0.40000001],
    #       [ 0.30000001,  0.1       ]], dtype=float32)
    
    tf.gather_nd(mu, tf.stack((p, tf.zeros_like(p)), axis=-1)).eval()
    
    #array([[ 0.1       ,  0.30000001],
    #       [ 0.2       ,  0.40000001],
    #       [ 0.30000001,  0.1       ]], dtype=float32)
    

答案 1 :(得分:0)

您可以使用tf.map_fn

$arr = (array_count_values([33, 32, 23, 33, 22, 23, 32, 33, 33]));
$new_arr = [];

foreach($arr as $k => $a) {
    $new_arr[number_format($k, 6)] = $a;
}

print_r($new_arr);

Array ( [33.000000] => 4 [32.000000] => 2 [23.000000] => 2 [22.000000] => 1 ) 充当在 x= tf.map_fn(lambda u: tf.gather(tf.squeeze(mu),u),p,dtype=mu.dtype) 的第一维上运行的循环,并且对于每个此类切片,它都适用map_fn