tensorflow中是否有一个等效的pytorch函数名为“ index_select”

时间:2019-10-19 14:39:35

标签: tensorflow pytorch

我试图将pytorch代码转换为tensorflow。所以我想知道tensorflow中是否有一个等效的pytorch函数,名为“ index_select”

1 个答案:

答案 0 :(得分:1)

我还没有发现类似的api可以直接实现它,但是我们可以使用tf.slice来实现它。

def tf_index_select(input_, dim, indices):
    """
    input_(tensor): input tensor
    dim(int): dimension
    indices(list): selected indices list
    """
    shape = input_.get_shape().as_list()
    if dim == -1:
        dim = len(shape)-1
    shape[dim] = 1

    tmp = []
    for idx in indices:
        begin = [0]*len(shape)
        begin[dim] = idx
        tmp.append(tf.slice(input_, begin, shape))
    res = tf.concat(tmp, axis=dim)

    return res

以下是显示等效性的示例。

import tensorflow as tf
import torch
import numpy as np

a = np.arange(2*3*4).reshape(2,3,4)
dim = 1
indices = [0,2]
# array([[[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]],

#        [[12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23]]])

# pytorch
res = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])

# tensorflow
res = tf_index_select(tf.constant(a), dim, indices)
# tensor([[[ 0,  1,  2,  3],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [20, 21, 22, 23]]])