使用dask数组去除不良像素

时间:2018-06-19 18:26:50

标签: python dask

我有一个非常大的4维数据集,其中最后两个维是在图像检测器上拍摄的图像。该检测器上的某些像素不起作用,这些无效像素的值为0。我想在后期处理中将这些像素的值设置为其相邻像素的中值。数据集的范围从8 GB到可能的TB大小,所以我想使用dask数组,因为我可以将坏点去除与其他处理步骤链接在一起。

查找坏点很容易,但是我不确定如何最好地获得邻居的中值。

最小示例:

import numpy as np
import dask.array as da

data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0

# Some kind of processing

dask_array_without_dead_pixels = dask_array + dead_pixel_values_array

所以我的问题是:我怎么得到dead_pixel_values_array?还是其他一些巧妙的方法来去除坏点?

2 个答案:

答案 0 :(得分:2)

这是dask数组的区域中位数的实现:https://dask-ndfilters.readthedocs.io/en/latest/dask_ndfilters.html#dask_ndfilters.median_filter

如果您自己想要更通用的内容,或者不想安装dask_ndfilters,则应阅读map_overlap,该书可让您访问来自相邻块的每个块周围的数据,并且因此可以考虑您要进行的计算。

答案 1 :(得分:0)

我设法通过首先将图像尺寸(最后两个)重新分块为一个块,然后使用map_blocks来做到这一点。 map_overlay适用于小型数据集,但对于较大的数据集,内存使用量远大于可用内存。

import numpy as np
import dask.array as da
import matplotlib.pyplot as plt

def remove_dead_pixels(data, dead_pixels):
    dif0 = np.roll(data, shift=1, axis=-2) * dead_pixels
    dif1 = np.roll(data, shift=-1, axis=-2) * dead_pixels
    dif2 = np.roll(data, shift=1, axis=-1) * dead_pixels
    dif3 = np.roll(data, shift=-1, axis=-1) * dead_pixels
    output_data = np.median(np.stack([dif0, dif1, dif2, dif3], axis=-1),
                            axis=-1, overwrite_input=True, keepdims=False)
    return output_data

# Making artificial data
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array = dask_array.rechunk((5, 5, 20, 20))

# Finding dead pixels
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0

# Getting replacement values
dead_pixel_values_array = da.map_blocks(
        remove_dead_pixels, dask_array, dead_pixels, dtype=dask_array.dtype,
        chunks=dask_array.chunks)
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array

# Plotting result
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
axarr[0].imshow(dask_array.sum(axis=(0, 1)).compute())
axarr[1].imshow(dask_array_without_dead_pixels.sum(axis=(0, 1)).compute())
fig.tight_layout()
fig.savefig("image.jpg")

Left image: data with two dead pixels. Right image: data with the dead pixels removed