我有一批[5,1,100,100]
(batch_size x dims x ht x wd
)形状的分割蒙版,我必须在tensorboardX中以RGB图像批次[5,3,100,100]
显示这些分割蒙版。我想在分段蒙版的第二个轴上添加两个虚拟尺寸以使其为[5,3,100,100]
,所以当我将其传递给torch.utils.make_grid
时不会出现任何尺寸不匹配错误。我已经尝试过unsqueeze
,expand
和view
,但是我做不到。有什么建议吗?
答案 0 :(得分:3)
您可以使用expand
,repeat
或repeat_interleave
:
import torch
x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)
print(x1_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape) # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape) # torch.Size([5, 3, 100, 100])
请注意,如文档所述:
扩展张量不会分配新的内存,而只会在现有张量上创建一个新视图,其中通过将步幅设置为0将大小为1的维扩展为更大的大小。 >任何大小为1的维度都可以扩展为任意值,而无需分配新内存。
不同于
expand()
,此函数复制张量的数据。
答案 1 :(得分:0)
Expand是一种我不断告诉自己不要阅读文档的方法,该文档的内容为:
扩展张量不会分配新的内存,而只会在现有张量上创建新视图
由于在PyTorch中没有类似视图的内容,至少我从来没有将它们视为对象,因此无法创建它们。唯一的是:大步向前。
展开也可以缩小。
t21 = torch.rand(2,1)
print(t)
print(t.shape)
print(t.stride())
t25 = t.expand(-1,5)
print(t25.shape)
print(t25)
print(t25.stride())
t123 = t.expand(1,-1,3)
print(t123.shape)
print(t123)
print(t123.stride())
# tensor([[0.1353],
# [0.5809]])
# torch.Size([2, 1])
# (1, 1)
# torch.Size([2, 5])
# tensor([[0.1353, 0.1353, 0.1353, 0.1353, 0.1353],
# [0.5809, 0.5809, 0.5809, 0.5809, 0.5809]])
# (1, 0)
# torch.Size([1, 2, 3])
# tensor([[[0.1353, 0.1353, 0.1353],
# [0.5809, 0.5809, 0.5809]]])
# (2, 1, 0)