使用pytorch从大图像进行补丁分类

时间:2019-03-29 15:43:35

标签: python deep-learning pytorch tensor

我正在使用以下功能通过训练有素的pytorch模型对图像进行分类。对于较小的输入图像,它可以正常工作。

def test(net, STRIDE-16, BATCH_SIZE=20, WINDOW_SIZE= (256,256)):

# Use the network on the test image
img = (1 / 255 * np.asarray(io.imread("C:/bd/R1C1.tif"), dtype='float32'))


all_preds = []

# Switch the network to inference mode
net.eval()

pred = np.zeros(img.shape[:2] + (N_CLASSES,))


for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, 
STRIDE, WINDOW_SIZE)))):

# Build the tensor
image_patches = [np.copy(img[x:x + w, y:y + h]).transpose((2, 0, 1)) for 
x, y, w, h in coords]

image_patches = np.asarray(image_patches)
image_patches = Variable(torch.from_numpy(image_patches).cpu(), 
volatile=True)

 # Do the inference
 outs = net(image_patches)
 outs = outs.data.cpu().numpy()

 # Fill in the results array
 for out, (x, y, w, h) in zip(outs, coords):
     out = out.transpose((1, 2, 0))
     pred[x:x + w, y:y + h] += out
  del (outs)

pred = np.argmax(pred, axis=-1)

all_preds.append(pred)


return all_preds

当我加载大img时,说(40k x 40k)我得到MemoryError。如何避免此内存错误。也许通过使用一批较小的图像?如何有效实施?

0 个答案:

没有答案