如何初始化Pytorch BatchNorm2d的均值和方差?

时间:2019-10-11 05:48:21

标签: tensorflow pytorch

我正在将TensorFlow模型转换为Pytorch。我想使用TensorFlow模型初始化BatchNorm2d的均值和方差。 我是这样做的:

bn.running_mean = torch.nn.Parameter(torch.Tensor(TF_param))

我收到此错误:

RuntimeError: the derivative for 'running_mean' is not implemented

但是适用于bn.weightbn.bias。有什么方法可以使用我预先训练的Tensorflow模型初始化均值和方差吗?在Pytorch中是否有类似moving_mean_initializermoving_variance_initializer的东西?

谢谢!

1 个答案:

答案 0 :(得分:1)

批处理规范层的运行平​​均值和方差不是nn.Parameters,而是该层的buffer

我认为您只需分配一个torch.tensor,而无需在其周围包裹nn.Parameter