尺寸超出范围(预计在[-2,1]范围内,但得到2)

时间:2018-06-09 17:16:38

标签: pytorch

为什么弹出以下错误?应该在这个范围内,为什么? -2维意味着什么?

RuntimeError: dimension out of range (expected to be in range of [-2, 1], but got 2)

此代码将产生错误

import torch 

torch.bmm(torch.randn(1000, 784) , torch.randn(784, 10))

2 个答案:

答案 0 :(得分:1)

torch.mm

  

执行矩阵mat1和mat2的矩阵乘法。

     

如果mat1是(n×m)张量,mat2是(m×p)张量,out将是(n×p)       张量。

torch.bmm

  

执行存储在batch1中的矩阵的批处理矩阵 - 矩阵乘积   和batch2。

     

batch1和batch2必须是3-D张量,每个张量包含相同的数字   矩阵。

     

如果batch1是(b×n×m)张量,则batch2是(b×m×p)张量,out将是   a(b×n×p)张量。

以下代码段有效。

from pets import Dog, Cat

class PetCreator:
  @classmethod
  def __call__(cls, pet_type):
    if pet_type == "cat": return Cat()
    elif pet_type == "dog": return Dog()
    else: raise SomeError

def pet_creator(pet_type):
  if pet_type == "cat": return Cat()
  elif pet_type == "dog": return Dog()
  else: raise SomeError

if __name__ == "__main__":
  fav_pet_type = input() # "cat"
  my_cat = pet_creator(fav_pet_type) #this?
  my_cat = PetCreator(fav_pet_type) #or that?

答案 1 :(得分:0)

方法$textcolor实现批量矩阵 - 矩阵产品。对于普通的矩阵 - 矩阵产品,您需要两个具有两个2D矩阵才能创建产品。

使用torch.bmm您可以创建产品甚至批量生成,但当然您需要包含批量维度,因此您需要两个输入3维矩阵。

关于torch.bmm中使用的尺寸的方式:

  

如果 batch1 (b×n×m)张量,    batch2 (b×m×p)张量,输出为(b×n×p)张量。

https://pytorch.org/docs/master/torch.html#torch.bmm