我正在尝试将功能应用于3d火炬张量,而将功能应用于通过3d火炬张量的轴1读取的2d张量。
例如,我有一个形状为(51, 128, 20100)
的火炬张量(名称为autoencode_logprob
的变量),并且函数(rawid2sentence)在形状为(51, 20100)
的输入上运行。 / p>
现在,我编写了可与naive for循环一起运行的代码,并与range(128)一一循环。
但是,它太慢了。以下是重要的代码部分。
autoencode_logprobs是3d张量,我需要沿其第二个轴应用rawids2sentence
函数。对它进行矢量化有帮助吗?
for i in range(128):
output_sent = self.dictionary.rawids2sentence(
autoencode_logprobs[:, i].max(1)[
1].data.cpu().numpy(),
oov_dicts[i],
)
output_sent_encoding = ifst_model.encode([output_sent])
答案 0 :(得分:1)
由于我不知道rawids2sentence
或encode
函数的作用,因此我可以帮助您进行最大操作。
在以下语句中,
autoencode_logprobs[:, i].max(1)[1]
您为每个dim=1
张量确定沿着51 x 20100
的最大值的索引。因此,输出是大小为51
的向量。
您可以在形状为51 x 128 x 20100
的完整张量中执行相同的操作,并以128 x 51
张量获得输出。
autoencode_logprobs.transpose(0, 1).max(2)[1] # 128 x 51
因此,如果您的rawids2sentence
或encode
方法可以处理批量输入,则上述更改应该对您有效,而不会产生任何循环。