我正在Google colab上使用torch_xla
我的网络是下面这个简单的网络:
class EmbeddingNet(nn.Module):
    """Combine user and movie embeddings into a single score per pair.

    The two embeddings are multiplied element-wise, passed through dropout,
    projected to one value by a linear layer and clamped with ReLU.

    Args:
        n_users: Number of rows in the user embedding table.
        n_movies: Number of rows in the movie embedding table.
        n_factors: Width of each embedding vector.
        embedding_dropout: Dropout probability applied to the combined
            embedding.
        dropouts: Accepted for interface compatibility; this class does not
            use it.
    """

    def __init__(self, n_users, n_movies,
                 n_factors=50, embedding_dropout=0.02,
                 dropouts=0.2):
        super().__init__()
        self.u = nn.Embedding(n_users, n_factors)
        self.m = nn.Embedding(n_movies, n_factors)
        self.drop = nn.Dropout(embedding_dropout)
        self.fc = nn.Linear(n_factors, 1)

    def forward(self, x):
        """Score each (user, movie) index pair.

        Args:
            x: 2-D integer index tensor; column 0 holds user indices and
                column 1 holds movie indices.

        Returns:
            Tensor of shape (batch, 1) with non-negative scores.
        """
        # Element-wise product of the two embeddings. The original wrapped
        # this in torch.cat([...], 1) over a single tensor, which is a no-op.
        x = self.u(x[:, 0]) * self.m(x[:, 1])
        x = self.drop(x)
        return torch.relu(self.fc(x))
初始化网络:
# Build the model from the user count `n` and movie count `m` (both defined
# outside this snippet) with 50 factors, then print its layer summary.
net = EmbeddingNet(n_users=n, n_movies=m, n_factors=50)
print(net)
EmbeddingNet(
(u): Embedding(138493, 50)
(m): Embedding(26744, 50)
(drop): Dropout(p=0.02)
(fc): Linear(in_features=50, out_features=1, bias=True)
)
然后我尝试将其转换为xla模型:
import torch_xla_py.xla_model as xm

# One device string per core: ':0' through ':7'.
devices = [':{}'.format(core) for core in range(8)]

# Example batch of 2000 (user, movie) index pairs.
inputs = torch.zeros(2000, 2).long()
# net(inputs) has shape (2000, 1); the target gets the same shape so that
# F.mse_loss compares element-wise instead of broadcasting (2000, 1) against
# (2000,).
# NOTE(review): this fixes the shape mismatch only; whether it affects the
# reported "differentiation of aten::mse_loss" error is unverified.
target = torch.zeros(2000, 1, dtype=torch.float)

xla_model = xm.XlaModel(
    net, [inputs],
    loss_fn=F.mse_loss,
    target=target,
    num_cores=8,
    devices=devices)
然后我得到了错误:
/usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in __init__(self, model, inputs, target, loss_fn, num_cores, devices, loader_prefetch, full_conv_precision)
496 devices=devices,
497 input_gradients=loss_output_grads,
--> 498 full_conv_precision=full_conv_precision)
499 else:
500 self._xla_model, self._traced_model = create_xla_model(
/usr/local/lib/python3.6/dist-packages/torch_xla_py/xla_model.py in create_xla_model(model, inputs, num_cores, devices, input_gradients, full_conv_precision)
235 if input_gradients is not None:
236 xla_model.set_input_gradients(input_gradients)
--> 237 xla_model(*inputs_xla)
238 return xla_model, traced_model
239
RuntimeError: differentiation of aten::mse_loss is not supported, or it is missing necessary type information
我尝试了l1_loss,获得了相同的结果。
我还尝试了另一种初始化方式:
import torch_xla

# One device string per core: ':0' through ':7'.
devices = [':{}'.format(core) for core in range(8)]
inputs = torch.zeros(2000, 2).long()
target = torch.zeros(2000, dtype=torch.float)

# The pasted snippet repeated the import/trace/wrap lines twice; they are
# needed only once, after the example tensors are defined.
# NOTE(review): EmbeddingNet.forward takes a single tensor, yet two example
# inputs (inputs, target) are traced here -- confirm this is intended.
traced_model = torch.jit.trace(net, (inputs, target))
xla_model = torch_xla._XLAC.XlaModule(traced_model)
output_xla = xla_model((torch_xla._XLAC.XLATensor(inputs), torch_xla._XLAC.XLATensor(target)))
它返回了相同的错误:differentiation of aten::mse_loss is not supported, or it is missing necessary type information(即不支持对 aten::mse_loss 求导,或者缺少必要的类型信息)。