训练时在图像批处理数据上添加白噪声

时间:2019-02-11 08:49:31

标签: python numpy deep-learning pytorch noise

我正在尝试一个降噪模型,目标是打印出每批次clean / add_noise / model_output

我正在使用PyTorch DataLoader。每个图像都有shape = (256, 128)batch_size = 10,因此每个批次的大小为(10, 256, 128)。我想打印出每批的第一批数据,即batch_data[0]

每个图像都有shape = (256, 128)

我编写了一个添加噪声的函数,如下所示:

def add_noise(data, bs, target_snr, noise_type):

    if noise_type == 'white':
        noise = acoustics.generator.white(bs*256*128).reshape(bs, 256, 128)

    if noise_type == 'pink':
        noise = acoustics.generator.pink(bs*256*128).reshape(bs, 256, 128)


    print ('data shape = ', data.shape)

    average = np.mean(data)
    std = np.std(noise)
    current_snr = average/std

    noise = noise * (current_snr/ target_snr)
    data = data + noise

    return data 

但是,它保持如下所示的返回错误消息:

TypeError: mean() missing 3 required positional argument: "dim", "keepdim", "dtype"

我应该如何处理?

2 个答案:

答案 0 :(得分:0)

您的data的形状是什么? type(data)是什么?
您是否在Numpy函数中传递了DataLoader张量?

看看Numpy的mean()函数的documentation,其中还包含一些示例。

该函数将类似 array 的对象作为其输入(例如,可以是2d矩阵),因此均值不能立即明确定义。您是否需要计算行,列或矩阵中所有数据的均值?计算中使用的数据类型是什么?

在第一种情况下,您需要提供要展平数组的尺寸。在第二种情况下,它应该开箱即用,将Numpy设为“默认是计算平整数组的均值” ,但是由于您使用的是PyTorch的{​​{1}},它可能需要定义它们。

由于您的DataLoader似乎是数字,因此类似的操作应该有效

average

答案 1 :(得分:0)

作为原始帖子下的第一条评论。数据是PyTorch张量,而我使用的是Numpy方法。我尝试使用torch.mean()torch.std(),它可以正常工作。