如何获得火炬::张量形状

时间:2021-02-18 22:18:49

标签: c++ pytorch libtorch

如果我们<<一个torch::Tensor

#include <torch/script.h>
int main()
{
    torch::Tensor input_torch = torch::zeros({2, 3, 4});
    std::cout << input_torch << std::endl;
    return 0;
}

我们看到

(1,.,.) = 
  0  0  0  0
  0  0  0  0
  0  0  0  0

(2,.,.) = 
  0  0  0  0
  0  0  0  0
  0  0  0  0
[ CPUFloatType{2,3,4} ]

如何获得张量形状(即 2,3,4)?我在 https://pytorch.org/cppdocs/api/classat_1_1_tensor.html?highlight=tensor 中搜索了 API 调用,但没有找到。我搜索了 operator<< 重载代码,也没有找到。

1 个答案:

答案 0 :(得分:0)

您可以使用 torch::sizes() 方法

IntArrayRef sizes()

它相当于python中的形状。此外,您可以通过调用 torch::size(dim) 访问给定轴(维度)的特定大小。这两个函数都在您链接的 API 页面中