在空间轴上计算火炬的方差

时间:2019-02-17 10:49:53

标签: tensorflow pytorch variance

我正在尝试在Pytorch中计算方差,但无法在多轴上进行。

我在Tensorflow中完成了类似的操作,但是由于torch.var函数将int作为尺寸而不是轴,因此无法在Pytorch上完成。 下面的代码是频道的最后一个代码,我希望轴= [2,3]

Lambda(lambda x: tf.nn.moments(x, axes=[1, 2]))

例如,如果input_dims =(5,10,25,25),则output_dims应该为(5,10,1,1)。

1 个答案:

答案 0 :(得分:1)

您可以做的一件事是在应用tensor.view()方法之前,使用var()将要计算方差的所有维度展平为一个维度:

torch.var(x.view(x.shape[0], x.shape[1], 1, -1,), dim=3, keepdim=True)

我使用keepdim=True保留了计算方差的尺寸,以获得所需的输出形状。