pytorch张量切片和内存使用

时间:2020-05-22 21:27:08

标签: pytorch

import torch
T = torch.FloatTensor(range(0,10 ** 6)) # 1M

#case 1:
torch.save(T, 'junk.pt')
# results in a 4 MB file size

#case 2:
torch.save(T[-20:], 'junk2.pt')
# results in a 4 MB file size

#case 3:
torch.save(torch.FloatTensor(T[-20:]), 'junk3.pt')
# results in a 4 MB file size

#case 4:
torch.save(torch.FloatTensor(T[-20:].tolist()), 'junk4.pt')
# results in a 405 Bytes file size

我的问题是:

(i)在情况3中,当我们创建一个新的张量时,生成的文件大小似乎令人惊讶。为什么这个新的张量不只是切片?

(ii)情况4是仅保存张量的一部分(切片)的最佳方法吗?

(iii)更一般而言,如果我想通过删除其值的前一半来“修剪”一个非常大的一维张量以节省内存,我是否必须像情况4一样继续进行操作,或者在那里一种不涉及创建python列表的更直接,计算成本更低的方法。

1 个答案:

答案 0 :(得分:1)

(i)在情况3中,当我们创建一个新的张量时,生成的文件大小似乎令人惊讶。为什么这个新张量不只是切片?

切片将创建张量视图,该图共享基础数据,但包含有关用于可见数据的内存偏移的信息。这样避免了必须频繁复制数据的情况,从而使许多操作效率更高。有关受影响的操作的列表,请参见PyTorch - Tensor Views

您正在处理基础数据很重要的少数情况之一。要保存张量,需要保存基础数据,否则偏移将不再有效。

torch.FloatTensor不会创建张量的副本(如果没有必要)。您可以验证其基础数据仍然相同(它们具有完全相同的内存位置):

torch.FloatTensor(T[-20:]).storage().data_ptr() == T.storage().data_ptr()
# => True

(ii)情况4是仅保存张量的一部分(切片)的最佳方法吗?

(iii)更一般而言,如果我想通过删除其值的前一半来“修剪”一个非常大的一维张量以节省内存,我是否必须像情况4一样继续进行操作,或者在那里一种更直接,更省钱的方式,无需创建python列表。

您很可能无法复制切片的数据,但是至少可以避免使用torch.Tensor.clone来从切片创建Python列表并从列表创建新的张量:

torch.save(T[-20:].clone(), 'junk5.pt')