Pytorch中的地图功能

时间:2018-04-17 12:04:41

标签: python mapping pytorch

Pytorch有没有地图功能? (类似于python中的map。)

我需要将1xDxhxw张量变量映射到1x(9D)xhxw张量,以通过其8个邻居嵌入来增强每个像素的嵌入。 Pytorch中是否有任何功能可以让我有效地做到这一点?

我尝试用这种方式在Python中使用map:

n, d, h, w = embedding.size()
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)
embedding = map(lambda i, j, M: M[:, :, i-1:i+2, j-1:j+2], range(1, h), range(1, w), embedding)

但它不适用于w > 2h > 2

1 个答案:

答案 0 :(得分:0)

从你的问题来看,目前尚不清楚你想要完成什么。

请注意,PyTorch支持完整的Python,但您正在做的是在最后一行代码中创建一个map对象。以下内容应该适用于您的目的(?我猜测):

import torch
import torch.nn as nn
n, d, h, w = 20, 3, 32, 32
embedding = torch.randn(n, d, h, w)
padder = nn.ReflectionPad2d(padding=1)
embedding = padder(embedding)

new = [embedding[:,:, (i-1):(i+2), (j-1):(j+2)] for i, j in zip(range(1,h), range(1,w))]

但是请注意,有更优雅的方法来填充张量(例如torch.chunk()或使用卷积操作补丁(例如torch.nn.Conv2d