我有一个非常大的4维数据集,其中最后两个维是在图像检测器上拍摄的图像。该检测器上的某些像素不起作用,这些无效像素的值为0。我想在后期处理中将这些像素的值设置为其相邻像素的中值。数据集的范围从8 GB到可能的TB大小,所以我想使用dask数组,因为我可以将坏点去除与其他处理步骤链接在一起。
查找坏点很容易,但是我不确定如何最好地获得邻居的中值。
最小示例:
import numpy as np
import dask.array as da
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0
# Some kind of processing
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array
所以我的问题是:我怎么得到dead_pixel_values_array
?还是其他一些巧妙的方法来去除坏点?
答案 0 :(得分:2)
这是dask数组的区域中位数的实现:https://dask-ndfilters.readthedocs.io/en/latest/dask_ndfilters.html#dask_ndfilters.median_filter
如果您自己想要更通用的内容,或者不想安装dask_ndfilters
,则应阅读map_overlap
,该书可让您访问来自相邻块的每个块周围的数据,并且因此可以考虑您要进行的计算。
答案 1 :(得分:0)
我设法通过首先将图像尺寸(最后两个)重新分块为一个块,然后使用map_blocks
来做到这一点。 map_overlay
适用于小型数据集,但对于较大的数据集,内存使用量远大于可用内存。
import numpy as np
import dask.array as da
import matplotlib.pyplot as plt
def remove_dead_pixels(data, dead_pixels):
dif0 = np.roll(data, shift=1, axis=-2) * dead_pixels
dif1 = np.roll(data, shift=-1, axis=-2) * dead_pixels
dif2 = np.roll(data, shift=1, axis=-1) * dead_pixels
dif3 = np.roll(data, shift=-1, axis=-1) * dead_pixels
output_data = np.median(np.stack([dif0, dif1, dif2, dif3], axis=-1),
axis=-1, overwrite_input=True, keepdims=False)
return output_data
# Making artificial data
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array = dask_array.rechunk((5, 5, 20, 20))
# Finding dead pixels
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0
# Getting replacement values
dead_pixel_values_array = da.map_blocks(
remove_dead_pixels, dask_array, dead_pixels, dtype=dask_array.dtype,
chunks=dask_array.chunks)
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array
# Plotting result
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
axarr[0].imshow(dask_array.sum(axis=(0, 1)).compute())
axarr[1].imshow(dask_array_without_dead_pixels.sum(axis=(0, 1)).compute())
fig.tight_layout()
fig.savefig("image.jpg")