differentiation of aten::mse_loss is not supported, or it is missing necessary type information

Time: 2019-05-14 07:53:08

Tags: google-cloud-platform pytorch google-colaboratory google-cloud-tpu tpu

I'm using torch_xla on Google Colab.

My network is a simple one:

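import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):

    def __init__(self, n_users, n_movies,
                 n_factors=50, embedding_dropout=0.02,
                 dropouts=0.2):  # note: dropouts is accepted but never used
        super().__init__()
        # One embedding table for users, one for movies
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.drop = nn.Dropout(embedding_dropout)
        # Maps the n_factors-dim interaction vector to a single rating
        self.fc = nn.Linear(n_factors, 1)

    def forward(self, x):
        # x: LongTensor of shape (batch, 2) holding (user_id, movie_id) pairs
        # torch.cat over a single tensor is effectively a no-op here
        x = torch.cat([self.u(x[:, 0]) * self.m(x[:, 1])], 1)
        x = self.drop(x)
        x = torch.relu(self.fc(x))
        return x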
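For reference, the forward pass above reduces to a few arithmetic steps per (user, movie) pair. The plain-Python sketch below mirrors them in eval mode, where dropout is the identity. The toy tables, weights, and the helper name `forward_one` are made up for illustration and are not part of the question.

```python
# Dependency-free sketch of EmbeddingNet.forward for one (user, movie) pair,
# in eval mode (dropout acts as the identity). All numbers are toy values.

def forward_one(user_table, movie_table, fc_weight, fc_bias, user_id, movie_id):
    u = user_table[user_id]      # like self.u(x[:, 0]) for one row
    m = movie_table[movie_id]    # like self.m(x[:, 1]) for one row
    # Element-wise product of the two embedding vectors
    h = [a * b for a, b in zip(u, m)]
    # self.fc: Linear(n_factors, 1) produces one scalar
    z = sum(w * v for w, v in zip(fc_weight, h)) + fc_bias
    return max(z, 0.0)           # torch.relu

# Toy setup: 2 users, 2 movies, n_factors = 3
user_table = [[1.0, 0.5, -1.0], [0.0, 2.0, 1.0]]
movie_table = [[2.0, 1.0, 1.0], [-1.0, 0.0, 3.0]]
fc_weight = [0.5, 1.0, -0.5]
fc_bias = 0.1

print(forward_one(user_table, movie_table, fc_weight, fc_bias, 0, 0))
print(forward_one(user_table, movie_table, fc_weight, fc_bias, 1, 1))
```

The second pair produces a negative pre-activation, so the ReLU clamps it to zero.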

Initializing the network:

# n and m: number of distinct users and movies (138493 and 26744, per the printout below)
net = EmbeddingNet(n_users=n, n_movies=m, n_factors=50)

print(net)

EmbeddingNet(
  (u): Embedding(138493, 50)
  (m): Embedding(26744, 50)
  (drop): Dropout(p=0.02)
  (fc): Linear(in_features=50, out_features=1, bias=True)
)
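The printed summary fixes every layer shape, so the number of trainable parameters follows directly (Dropout has none). The arithmetic is sketched below; in PyTorch, `sum(p.numel() for p in net.parameters())` should report the same figure.

```python
# Count trainable parameters of the printed EmbeddingNet from its layer shapes.
n_users, n_movies, n_factors = 138493, 26744, 50

embedding_u = n_users * n_factors    # Embedding(138493, 50)
embedding_m = n_movies * n_factors   # Embedding(26744, 50)
linear = n_factors * 1 + 1           # Linear(50, 1): weight + bias
total = embedding_u + embedding_m + linear

print(f"{total:,}")
```

Almost all of the roughly 8.26 million parameters sit in the two embedding tables.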

Then I tried to convert it into an XLA model:

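import torch
import torch.nn.functional as F
import torch_xla_py.xla_model as xm

# One device string per TPU core (8 cores)
devices = [':{}'.format(n) for n in range(0, 8)]
# Dummy batch: 2000 (user_id, movie_id) pairs and 2000 float targets
inputs = torch.zeros(2000, 2).long()
target = torch.zeros(2000, dtype=torch.float)

xla_model = xm.XlaModel(
    net, [inputs],
    loss_fn=F.mse_loss,
    target=target,
    num_cores=8,
    devices=devices)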
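Probably unrelated to the differentiation error, but worth checking: `net(inputs)` returns shape (2000, 1) because of the trailing `Linear(..., 1)`, while `target` has shape (2000,). Under the broadcasting rules PyTorch follows, those combine to (2000, 2000), so an elementwise loss would compare every prediction with every target; current PyTorch versions of `F.mse_loss` warn when the target size differs from the input size. The sketch below implements the broadcasting rule in plain Python; `broadcast_shape` is a hypothetical helper, not a PyTorch function.

```python
# Plain-Python sketch of NumPy/PyTorch-style shape broadcasting.

def broadcast_shape(a, b):
    """Return the broadcast shape of two shape tuples, or raise ValueError."""
    result = []
    # Compare dimensions from the right, treating missing ones as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        # A size-1 dimension stretches to match the other one.
        result.append(da if db == 1 else db)
    return tuple(reversed(result))

output_shape = (2000, 1)   # net(inputs)
target_shape = (2000,)     # torch.zeros(2000)
print(broadcast_shape(output_shape, target_shape))
```

Matching the shapes, for example by returning `x.squeeze(1)` from `forward` or creating `target` with shape (2000, 1), keeps the comparison elementwise.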

Then I got this error:

/usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in __init__(self, model, inputs, target, loss_fn, num_cores, devices, loader_prefetch, full_conv_precision)
    496           devices=devices,
    497           input_gradients=loss_output_grads,
--> 498           full_conv_precision=full_conv_precision)
    499     else:
    500       self._xla_model, self._traced_model = create_xla_model(

/usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in create_xla_model(model, inputs, num_cores, devices, input_gradients, full_conv_precision)
    235   if input_gradients is not None:
    236     xla_model.set_input_gradients(input_gradients)
--> 237   xla_model(*inputs_xla)
    238   return xla_model, traced_model
    239 

RuntimeError: differentiation of aten::mse_loss is not supported, or it is missing necessary type information

I tried l1_loss as well and got the same result.
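The message appears to come from PyTorch's TorchScript autodiff, which needs a derivative formula for every op it records in the traced graph; here it apparently has none for `aten::mse_loss` or `aten::l1_loss`. The derivatives themselves are simple. With mean reduction over N elements, d/dy_i of mean((y - t)^2) is 2(y_i - t_i)/N, and d/dy_i of mean(|y - t|) is sign(y_i - t_i)/N. The plain-Python sketch below writes both out and checks them against central finite differences; the function names are illustrative only.

```python
# Hand-written mean-reduced MSE and L1 losses with their gradients,
# checked against central finite differences.

def mse(y, t):
    return sum((a - b) ** 2 for a, b in zip(y, t)) / len(y)

def mse_grad(y, t):
    # d/dy_i of mean((y - t)^2) = 2 * (y_i - t_i) / N
    n = len(y)
    return [2.0 * (a - b) / n for a, b in zip(y, t)]

def l1(y, t):
    return sum(abs(a - b) for a, b in zip(y, t)) / len(y)

def l1_grad(y, t):
    # d/dy_i of mean(|y - t|) = sign(y_i - t_i) / N (0 chosen where y_i == t_i)
    n = len(y)
    return [((a > b) - (a < b)) / n for a, b in zip(y, t)]

def numeric_grad(f, y, t, eps=1e-6):
    grads = []
    for i in range(len(y)):
        up = list(y)
        up[i] += eps
        down = list(y)
        down[i] -= eps
        grads.append((f(up, t) - f(down, t)) / (2 * eps))
    return grads

y = [0.5, 2.0, -1.0]
t = [0.0, 0.0, 0.0]
print(mse(y, t), mse_grad(y, t))
print(l1(y, t), l1_grad(y, t))
```

Because the formulas are built only from subtraction, powers or absolute values, and a mean, one commonly suggested workaround is to pass a loss spelled out from those elementary ops, such as `((output - target) ** 2).mean()`, so that the trace records ops with known derivatives instead of `aten::mse_loss`. Whether the `torch_xla_py` tracer of this version accepts that is an untested assumption.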

I also tried a different way of initializing it:

devices = [':{}'.format(n) for n in range(0, 8)]
inputs = torch.zeros(2000,2).long()
target = torch.zeros(2000, dtype=torch.float)
import torch_xla
traced_model = torch.jit.trace(net, (inputs, target))
xla_model = torch_xla._XLAC.XlaModule(traced_model)
output_xla = xla_model((torch_xla._XLAC.XLATensor(inputs), torch_xla._XLAC.XLATensor(target)))

It returns the same error: differentiation of aten::mse_loss is not supported, or it is missing necessary type information.

0 answers:

No answers yet.