如何在c ++ api中使用NDArray?

时间:2017-08-10 07:33:53

标签: c++ mxnet

我正在更改mxnet中的c ++示例。我不明白如何分配NDArray对象。甚至没有基本的文档,这非常令人沮丧。

我尝试分配一个NDArray,但是通过声明一个实例它似乎没有分配数据,只有当我用数据填充一个数组时。这是对的吗?

// this code snippet does not work     
  NDArray a = NDArray(Shape(10, 20), Context::cpu());
  const float *dat = a.GetData();
  float result = *dat; // read memory violation
  result = *(dat + 10);

// this code snippet works
  NDArray b = NDArray(Shape(10, 20), Context::cpu());
  a.SampleUniform(1.0, 2.0, &b);
  const float *dat2 = b.GetData();
  float result2 = *dat2; // works!!
  result2 = *(dat2 + 10); 

是否有人使用过c ++ API和更改网络?

2 个答案:

答案 0 :(得分:2)

第三个参数delay_alloc: https://github.com/apache/incubator-mxnet/blob/master/cpp-package/include/mxnet-cpp/ndarray.h#L144

设置为false,您的代码将起作用。

答案 1 :(得分:1)

问题发布后已经有一段时间了。如果有帮助的话,这就是我的答案。

您可以使用std::vector和存储来自矢量的数据的矩阵Shape来定义要设置的数据。

  std::vector<mx_float> v {1.23, 4.56, 7.89, 5.71};

  // populates v vector data in a matrix of 1 row and 4 columns
  // mxnet::cpp::NDArray nda_array {v, Shape{1,4}, m_ctx};

  // populates v vector data in a matrix of 2 rows and 2 columns
  // where v[3] and v[4] are populated in second row
  mxnet::cpp::NDArray nda_array {v, Shape{2,2}, m_ctx};

  assert(nda_array.Size() == 4);

  assert(nda_array.At(0, 0) == v[0]);
  assert(nda_array.At(0, 1) == v[1]);
  assert(nda_array.At(1, 0) == v[2]);
  assert(nda_array.At(1, 1) == v[3]);