我试图将pytorch代码转换为tensorflow。所以我想知道tensorflow中是否有一个等效的pytorch函数,名为“ index_select”
答案 0 :(得分:1)
我还没有发现类似的api可以直接实现它,但是我们可以使用tf.slice
来实现它。
def tf_index_select(input_, dim, indices):
"""
input_(tensor): input tensor
dim(int): dimension
indices(list): selected indices list
"""
shape = input_.get_shape().as_list()
if dim == -1:
dim = len(shape)-1
shape[dim] = 1
tmp = []
for idx in indices:
begin = [0]*len(shape)
begin[dim] = idx
tmp.append(tf.slice(input_, begin, shape))
res = tf.concat(tmp, axis=dim)
return res
以下是显示等效性的示例。
import tensorflow as tf
import torch
import numpy as np
a = np.arange(2*3*4).reshape(2,3,4)
dim = 1
indices = [0,2]
# array([[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]])
# pytorch
res = torch.tensor(a).index_select(dim, torch.tensor(indices))
# tensor([[[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [20, 21, 22, 23]]])
# tensorflow
res = tf_index_select(tf.constant(a), dim, indices)
# tensor([[[ 0, 1, 2, 3],
# [ 8, 9, 10, 11]],
# [[12, 13, 14, 15],
# [20, 21, 22, 23]]])