Pytorch中是否有“张量”操作或函数,其功能类似于OpenCV中的cv2.dilate?

时间:2019-05-21 09:55:44

标签: python opencv pytorch

我通过网络建立了几个口罩。这些掩码存储在torch.tensor变量中。我想在cv2.dilate的每个频道上进行类似tensor的操作。

我知道有一种方法可以将tensor转换为numpy.ndarray,然后使用cv2.dilate循环将for应用于每个通道。但是由于大约有32个通道,此方法可能会减慢网络中的转发操作。

3 个答案:

答案 0 :(得分:2)

我认为膨胀实际上是在火炬中进行conv2d操作。参见下面的代码

import cv2
import numpy as np
import torch

im = np.array([ [0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0],
                [0, 1, 1, 0, 0],
                [0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0] ], dtype=np.float32)
kernel = np.array([ [1, 1, 1],
                    [1, 1, 1],
                    [1, 1, 1] ], dtype=np.float32)
print(cv2.dilate(im, kernel))
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]
#  [1. 1. 1. 1. 1.]
#  [0. 0. 1. 1. 1.]]
im_tensor = torch.Tensor(np.expand_dims(np.expand_dims(im, 0), 0)) # size:(1, 1, 5, 5)
kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3)
torch_result = torch.clamp(torch.nn.functional.conv2d(im_tensor, kernel_tensor, padding=(1, 1)), 0, 1)
print(torch_result)
# tensor([[[[1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 0.],
#           [1., 1., 1., 1., 1.],
#           [1., 1., 1., 1., 1.],
#           [0., 0., 1., 1., 1.]]]])

答案 1 :(得分:2)

PyTorch 张量的膨胀和腐蚀函数在 Kornia 库 https://kornia.readthedocs.io/en/latest/morphology.html 中实现。

答案 2 :(得分:2)

编辑:

我为此创建了一个库;该库名为 nnMorpho,可以通过 pip install nnMorpho 安装。 我使用的原则是下面描述的原则(:使用 PyTorch 的展开功能)。目前该库处于早期阶段(仅实现了基本操作),但我会尝试更新它以包含更多种类的操作和参数。

Dilation 和 convd2d 不一样

Dilation 和 convd2d 完全不同:粗略地说,convd2d 执行线性过滤器(这意味着它在像素周围进行计算求和),而膨胀执行非线性过滤器(取像素周围的最大值)。< /p>

Kornia 无法正常工作

我检查了 kornia 的代码,它没有按预期工作。我在底部展示了几张图片来证明这一点。我与 the morphology library from scipy 进行了比较,因为它比 cv2 更受限制(我无法解决运行代码的错误)。特别是,我使用了 greyscale dilation

一种在 PyTorch 中进行形态学的方法

在 PyTorch 中有一种方法可以进行数学形态学操作。在处理膨胀和侵蚀时,您面临的主要问题是您必须考虑每个像素的邻域来计算最大值(如果处理灰度结构元素,则可能计算和和差)。这个问题是由 PyTorch 的函数 unfold 解决的;它目前仅支持批量图像类张量(即:尺寸为 (B,C,H,W) 的 4D 张量),但这对于您的需求来说应该不是问题。其余的只是正常操作。

我加入了一个做膨胀的代码(侵蚀是类似的)和例子:

import numpy as np
import torch
from torch.nn import functional as f
from scipy.ndimage import grey_dilation as dilation_scipy
from kornia.morphology import dilation as dilation_kornia
import matplotlib.pyplot as plt


# Definition of the dilation using PyTorch
def dilation_pytorch(image, strel, origin=(0, 0), border_value=0):
    # first pad the image to have correct unfolding; here is where the origins is used
    image_pad = f.pad(image, [origin[0], strel.shape[0] - origin[0] - 1, origin[1], strel.shape[1] - origin[1] - 1], mode='constant', value=border_value)
    # Unfold the image to be able to perform operation on neighborhoods
    image_unfold = f.unfold(image_pad.unsqueeze(0).unsqueeze(0), kernel_size=strel.shape)
    # Flatten the structural element since its two dimensions have been flatten when unfolding
    strel_flatten = torch.flatten(strel).unsqueeze(0).unsqueeze(-1)
    # Perform the greyscale operation; sum would be replaced by rest if you want erosion
    sums = image_unfold + strel_flatten
    # Take maximum over the neighborhood
    result, _ = sums.max(dim=1)
    # Reshape the image to recover initial shape
    return torch.reshape(result, image.shape)


# Test image
image = np.zeros((7, 7), dtype=int)
image[2:5, 2:5] = 1
image[4, 4] = 2
image[2, 3] = 3

plt.figure()
plt.imshow(image, cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Original image')

# Structural element square 3x3
strel = np.ones((3, 3))

# Origin of the structural element
origin = (1, 1)

# Scipy
dilated_image_scipy = dilation_scipy(image, size=(3, 3), structure=strel)

plt.figure()
plt.imshow(dilated_image_scipy, cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Dilated image - Scipy')

# PyTorch
image_tensor = torch.tensor(image, dtype=torch.float)
strel_tensor = torch.tensor(strel, dtype=torch.float)
dilated_image_pytorch = dilation_pytorch(image_tensor, strel_tensor, origin=origin, border_value=-1000)

plt.figure()
plt.imshow(dilated_image_pytorch.cpu().numpy(), cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Dilated image - PyTorch')

# Kornia
dilated_image_kornia = dilation_kornia(image_tensor.unsqueeze(0).unsqueeze(0), strel_tensor)
plt.figure()
plt.imshow(dilated_image_kornia.cpu().numpy()[0, 0, :, :], cmap='Greys', vmin=image.min(), vmax=image.max(), origin='lower')
plt.title('Dilated image - Kornia')

plt.show()

The original image proposed in Scipy documentation

dilated image by scipy

dilated image by pytorch

dilated image by kornia

关于起源的考虑

原点是膨胀和侵蚀的关键参数。它操作移动图像。如果你想让你的图像不移动,你应该把它放在中间(这意味着有一个奇数大小的结构元素)。我尝试在 scipy 中使用它,但效果不佳,因为它在所有维度上都是相同的(在处理非方形结构元素时会出现问题)。我展示的代码正确地考虑了原点。