我正在将TensorFlow模型转换为Pytorch。我想使用TensorFlow模型初始化BatchNorm2d的均值和方差。 我是这样做的:
bn.running_mean = torch.nn.Parameter(torch.Tensor(TF_param))
我收到此错误:
RuntimeError: the derivative for 'running_mean' is not implemented
但是适用于bn.weight
和bn.bias
。有什么方法可以使用我预先训练的Tensorflow模型初始化均值和方差吗?在Pytorch中是否有类似moving_mean_initializer
和moving_variance_initializer
的东西?
谢谢!
答案 0 :(得分:1)
批处理规范层的运行平均值和方差不是nn.Parameters
,而是该层的buffer。
我认为您只需分配一个torch.tensor
,而无需在其周围包裹nn.Parameter
。