如何预测MXNet的C ++中的小批量示例?

时间:2017-10-30 02:03:40

标签: mxnet

在python界面中,我们可以使用小批量示例进行预测,如net([[1,2],[3,4],[5,6]])

但是在C ++中,我找不到办法做到这一点。

假设调用网络预测单个示例需要10ms。如果有10000个例子需要进行预测,那就是100s

void OneInputOneOutputPredict(PredictorHandle pred_hnd, std::vector<mx_float> vector_data, std::vector<mx_float> &output)
{
    MXPredSetInput(pred_hnd, "data", vector_data.data(), vector_data.size());

    // Do Predict Forward
    MXPredForward(pred_hnd);

    mx_uint output_index = 0;
    mx_uint *shape = 0;
    mx_uint shape_len;
    MXPredGetOutputShape(pred_hnd, output_index, &shape, &shape_len);
    size_t size = 1;
    for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];

    std::vector<float> data(size);
    assert(0 == MXPredGetOutput(pred_hnd, output_index, &(data[0]), size));
    output = data;
}

//very long time
for(int step=0;step<10000;step++)
    OneInputOneOutputPredict(pred_hnd, vector_data, vector_label);

我们可以使用C ++中的代码或某种方式对其进行矢量化,以使其快速预测吗?

1 个答案:

答案 0 :(得分:0)

原本 input_shape_data看起来像这样

const mx_uint input_shape_data[4] = {1, static_cast<mx_uint>(data_len)};

现在,如果我想预测一个小批量(批量大小3)

const mx_uint input_shape_data[4] = {3, static_cast<mx_uint>(data_len)};

如果使用seq2seq模型。如果数据看起来像[[1,2],[3,4],[5,6]],现在只将其展平为列表{1,2,3,4,5,6},那么一切正常