假设我的张量为A,形状为[batch_size X长度X 1024]
我要执行以下操作:
对于批处理中的i元素,我想按其位置移动嵌入的所有“长度”元素的(嵌入1024)。
例如,向量A [0,0,:]应该保持不变,并且A [0,1,:]应该移位(或滚动)1,而A [0,15,:]应该移位15 。
这是针对批次中的所有元素的。
到目前为止,我是通过for循环完成的,但是效率不高
下面是我的for循环代码:
x = # [batchsize , length , 1024]
new_embedding = []
llist = []
batch_size = x.shape[0]
seq_len = x.shape[1]
for sample in range(batch_size):
for token in range(seq_len):
orig = x[sample , token , : ]
new_embedding.append(torch.roll(orig , token , 0))
llist.append(torch.stack(new_embedding , 0))
new_embedding = []
x = torch.stack(llist , 0)