用于C ++的mxnet ndarray迭代器

时间:2018-02-12 09:59:47

标签: c++ iterator mxnet

我想在C ++中训练一个简单的分类器,非常像C++ mnist example,但是我的数据并没有存储在HD上,而是已经加载到内存中,比如mxnet NDArray。在Python中,为此目的,有一个方便的NDArrayIter,c.f。 Module tutorial

C ++有这样的NDArray迭代器吗?

浏览代码时,我发现可以从MXDataIterMXListDataIters读取所有可能的MXDataIterGetIterInfo

#include "mxnet-cpp/io.h"
using namespace std;
using namespace mxnet::cpp;

int main(int argc, char** argv) {
  Context ctx = Context::cpu();  // Use CPU

  mx_uint num_data_iter_creators;
  DataIterCreator *data_iter_creators = nullptr;

  int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators);
  CHECK_EQ(r, 0);
  cout << "num_data_iter_creators = " << num_data_iter_creators << endl;
  //output: num_data_iter_creators = 8

  const char *name;
  const char *description;
  mx_uint num_args;
  const char **arg_names;
  const char **arg_type_infos;
  const char **arg_descriptions;

  for (mx_uint i = 0; i < num_data_iter_creators; i++) {
      r = MXDataIterGetIterInfo(data_iter_creators[i], &name, &description,
                                &num_args, &arg_names, &arg_type_infos,
                                &arg_descriptions);
      CHECK_EQ(r, 0);
      cout << " i: " << i << ", name: " << name << endl;
  }

  MXNotifyShutdown();
  return 0;
}

产生八个MXDataIter()&#39;

num_data_iter_creators = 8
 i: 0, name: ImageDetRecordIter
 i: 1, name: CSVIter
 i: 2, name: ImageRecordIter_v1
 i: 3, name: ImageRecordUInt8Iter_v1
 i: 4, name: MNISTIter
 i: 5, name: ImageRecordIter
 i: 6, name: ImageRecordUInt8Iter
 i: 7, name: LibSVMIter

所以在我看来,对于C ++,没有NDArray迭代器,最简单的解决方案是将我的数据写入csv文件,然后再将其加载到MXDataIter(CSVIter)。另一种可能性是手动将数据分成批量NDArray并将其提供给培训,但这也感觉很笨拙。

1 个答案:

答案 0 :(得分:1)

不幸的是,C ++包中没有NDArrayIter。

但是如果你真的需要的话,我会说它应该不难实现。看看它是如何在Python中实现的,也许您可​​以通过C ++实现回馈社区 - https://github.com/apache/incubator-mxnet/blob/fe5b56e419d454dc8f42f0307f53ced133804ca7/python/mxnet/io.py#L544