我需要连接一长串小的张量。每个小张量都是给定(非常简单)的常数矩阵的一个切片。这是代码:
max_node, counter = 0, 0
batch_size, n_days = (1000, 10)
n_interactions_in = torch.randint(low=100,high=200,size=(batch_size,n_days), dtype=torch.long)
max_interactions = n_interactions_in.max()
delay_table = torch.arange(n_days, device=device, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1
edge_delay_buf = []
for b in range(batch_size):
delay_vec = [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
edge_delay_buf.append(torch.cat(delay_vec))
res = torch.cat(edge_delay_buf)
这需要很多时间。有没有一种方法可以有效地简化edge_delay_buf中每个元素的创建? 我尝试了多种变体,例如用列表串联替换for循环,其中结果是列表列表,然后展平列表,并在展平列表上应用torch.cat。但是,它并没有太大改善。由于某种原因,切片操作花费的时间太长。
有没有一种方法可以使切片更快?有没有办法使循环更有效/并行?
注意:虽然在此示例中使用割炬,但我也可以使用numpy。 注意2:对于在其他论坛上的重复帖子,我深表歉意。
答案 0 :(得分:0)
首先通过列表附加替换内部串联,最后只进行单个串联,这应该快得多。
max_node, counter = 0, 0
batch_size, n_days = (1000, 10)
n_interactions_in = torch.randint(low=100,high=200,size=(batch_size,n_days), dtype=torch.long)
max_interactions = n_interactions_in.max()
delay_table = torch.arange(n_days, device=device, dtype=torch.float).expand([max_interactions, n_days]).t().contiguous()
delay_table = n_days - delay_table - 1
edge_delay_buf = []
for b in range(batch_size):
delay_vec = [delay_table[d, :n_interactions_in[b, d]] for d in range(n_days)]
edge_delay_buf += delay_vec
res = torch.cat(edge_delay_buf)
然后,如果仍然不够快,则可以通过一次提取所有索引来提高效率。让我们看一下,您有一个形状为[N,M]的矩阵A,实际上可以提取出一些进行A [B,C]的元素,其中B是长度为K的向量,而C是形状为[L,K]的矩阵]。也许它可以满足您的需求。