我有一个到lstm的输入张量,它的形状是批处理 seq_len emb,我想用另一个形状为batch * seq_len
的张量沿着这个seq_len dim排序此男高音。例如:
有
[ [ [0,0]
[1,1]
[2,2] ]
[ [0,0]
[1,1]
[2,2] ]
]
,索引为
[[0,1,2]
[2,1,0] ]
希望输出为
[ [ [0,0]
[1,1]
[2,2] ]
[ [2,2]
[1,1]
[0,0] ]
]
pytorch中是否有任何花哨的操作可以使其快速运行? 谢谢