如何为torch.cat初始化Tensor

时间:2020-04-13 19:24:13

标签: pytorch

import torch

#Y_pred = ?

for xi in X_iter:
    y_pred = net(xi).argmax(dim=1)
    Y_pred = torch.cat([Y_pred, y_pred])

您如何初始化该张量,或者有更好的方法编写它?

1 个答案:

答案 0 :(得分:0)

您可以改为这样做:

Y_pred = torch.cat([net(xi).argmax(dim=1) for xi in X_iter])