为轴添加额外的尺寸

时间:2019-07-09 12:28:32

标签: python-3.x pytorch tensorboard

我有一批[5,1,100,100]batch_size x dims x ht x wd)形状的分割蒙版,我必须在tensorboardX中以RGB图像批次[5,3,100,100]显示这些分割蒙版。我想在分段蒙版的第二个轴上添加两个虚拟尺寸以使其为[5,3,100,100],所以当我将其传递给torch.utils.make_grid时不会出现任何尺寸不匹配错误。我已经尝试过unsqueezeexpandview,但是我做不到。有什么建议吗?

2 个答案:

答案 0 :(得分:3)

您可以使用expandrepeatrepeat_interleave

import torch

x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)

print(x1_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape)  # torch.Size([5, 3, 100, 100])

请注意,如文档所述:

  

扩展张量不会分配新的内存,而只会在现有张量上创建一个新视图,其中通过将步幅设置为0将大小为1的维扩展为更大的大小。 >任何大小为1的维度都可以扩展为任意值,而无需分配新内存。

  

不同于expand()此函数复制张量的数据

答案 1 :(得分:0)

Expand是一种我不断告诉自己不要阅读文档的方法,该文档的内容为:

  

扩展张量不会分配新的内存,而只会在现有张量上创建新视图

由于在PyTorch中没有类似视图的内容,至少我从来没有将它们视为对象,因此无法创建它们。唯一的是:大步向前。

展开也可以缩小。

t21 = torch.rand(2,1)
print(t)
print(t.shape)
print(t.stride())

t25 = t.expand(-1,5)
print(t25.shape)
print(t25)
print(t25.stride())

t123 = t.expand(1,-1,3)
print(t123.shape)
print(t123)
print(t123.stride())

# tensor([[0.1353],
#         [0.5809]])
# torch.Size([2, 1])
# (1, 1)
# torch.Size([2, 5])
# tensor([[0.1353, 0.1353, 0.1353, 0.1353, 0.1353],
#         [0.5809, 0.5809, 0.5809, 0.5809, 0.5809]])
# (1, 0)
# torch.Size([1, 2, 3])
# tensor([[[0.1353, 0.1353, 0.1353],
#          [0.5809, 0.5809, 0.5809]]])
# (2, 1, 0)