我正在尝试在Pytorch中计算方差,但无法在多轴上进行。
我在Tensorflow中完成了类似的操作,但是由于torch.var函数将int作为尺寸而不是轴,因此无法在Pytorch上完成。 下面的代码是频道的最后一个代码,我希望轴= [2,3]
Lambda(lambda x: tf.nn.moments(x, axes=[1, 2]))
例如,如果input_dims =(5,10,25,25),则output_dims应该为(5,10,1,1)。
答案 0 :(得分:1)
您可以做的一件事是在应用tensor.view()
方法之前,使用var()
将要计算方差的所有维度展平为一个维度:
torch.var(x.view(x.shape[0], x.shape[1], 1, -1,), dim=3, keepdim=True)
我使用keepdim=True
保留了计算方差的尺寸,以获得所需的输出形状。