为什么多输入多输出模型无法导入 C++?

时间:2017-10-20 02:21:52

标签: mxnet

这是生成一个简单的多输入多输出 MLP 模型的 Python 代码——两个输入、两个输出。使用 HybridBlock 的 export 函数导出模型后,即可在 C++ 中使用。

Net Graph:

[图片:网络结构图(两张,原图缺失)]

from mxnet import nd
from mxnet.gluon import nn
import mxnet as mx

class HybridNet(nn.HybridBlock):
    """Two-input, two-output MLP used to demonstrate exporting a multi-output model.

    ``x`` is fed through ``dense0`` and ``y`` through ``dense1``; their ReLU
    activations are summed to form the first output. That sum is then passed
    through ``dense2`` followed by ReLU to form the second output.
    """

    def __init__(self, **kwargs):
        super(HybridNet, self).__init__(**kwargs)
        with self.name_scope():
            self.dense0 = nn.Dense(3)
            self.dense1 = nn.Dense(3)
            self.dense2 = nn.Dense(6)

    def hybrid_forward(self, F, x, y):
        # NOTE(review): keep the operator order (dense0, relu, dense1, relu, plus,
        # dense2, relu) unchanged -- the exported node names such as
        # "hybridnet0__plus0" / "hybridnet0_relu2" appear to be derived from it,
        # and the C++ loader hard-codes those names.
        branch_x = F.relu(self.dense0(x))
        branch_y = F.relu(self.dense1(y))
        merged = branch_x + branch_y
        head = F.relu(self.dense2(merged))
        return [merged, head]

net = HybridNet()
net.initialize()
net.hybridize()
x = nd.random.normal(shape=(4,3))
y = nd.random.normal(shape=(4,5))
res=net(x,y)
print "output1:",res[0]
print "output2:",res[1]
net.export('model')

我们可以重新导入模型,以检查导出是否正确。可以看到,两次得到的结果是相同的。

from collections import namedtuple
sym = mx.symbol.load('model-symbol.json') 
mod=mx.mod.Module(symbol=sym,data_names=['data0','data1'])
mod.bind(data_shapes=[('data0',(1,3)),('data1',(1,5))])
mod.load_params('model-0000.params')
Batch=namedtuple('Batch',['data'])
mod.forward(Batch(data=[x,y]))
print mod.get_outputs()

[图片(原图缺失)]

查看模型的输出名称:

# List the symbol's output names; the C++ code below refers to these outputs
# by the same names with the trailing "_output" suffix removed.
sym.list_outputs()
  

['hybridnet0__plus0_output','hybridnet0_relu2_output']

这是 C++ 代码的第一部分,它会引发错误。我已确认 num_input_nodes 和 num_output_nodes 都是 2,并使用 MXPredCreatePartialOut 来指定我的多任务输出。

[图片(原图缺失)]

#include <mxnet/c_predict_api.h>

#include <assert.h>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Read file to buffer.
//
// BufferFile loads the whole file at `file_path` into a heap buffer that it
// owns. The buffer holds GetLength() bytes of file content followed by one
// extra '\0' byte, so it can also be handed to APIs that expect a C string
// (e.g. the JSON symbol text). If the file cannot be opened, sized or read,
// GetLength() returns 0 and GetBuffer() returns nullptr.
//
// The class owns a raw allocation, so copying is disabled (a copy would
// delete the same buffer twice); ownership can be transferred by moving.
class BufferFile {
public:
    std::string file_path_;
    int length_ = 0;          // number of file bytes in buffer_ (excludes the trailing '\0')
    char* buffer_ = nullptr;  // owned; released in the destructor

    explicit BufferFile(std::string file_path)
        :file_path_(std::move(file_path)) {

        std::ifstream ifs(file_path_.c_str(), std::ios::in | std::ios::binary);
        if (!ifs) {
            std::cerr << "Can't open the file. Please check " << file_path_ << ". \n";
            return;
        }

        ifs.seekg(0, std::ios::end);
        const std::streamoff size = ifs.tellg();
        ifs.seekg(0, std::ios::beg);
        // tellg() yields -1 on failure; also reject files whose size plus the
        // terminator byte would not fit in the int length field.
        if (size < 0 || size >= std::numeric_limits<int>::max()) {
            std::cerr << "Can't determine a usable size for " << file_path_ << ". \n";
            return;
        }
        length_ = static_cast<int>(size);
        std::cout << file_path_.c_str() << " ... " << length_ << " bytes\n";

        buffer_ = new char[length_ + 1];
        buffer_[length_] = '\0';
        if (!ifs.read(buffer_, length_)) {
            // Do not hand out a partially filled buffer.
            std::cerr << "Failed to read " << file_path_ << ". \n";
            delete[] buffer_;
            buffer_ = nullptr;
            length_ = 0;
        }
    }

    BufferFile(const BufferFile&) = delete;
    BufferFile& operator=(const BufferFile&) = delete;

    BufferFile(BufferFile&& other) noexcept
        : file_path_(std::move(other.file_path_)),
          length_(other.length_),
          buffer_(other.buffer_) {
        other.length_ = 0;
        other.buffer_ = nullptr;
    }

    BufferFile& operator=(BufferFile&& other) noexcept {
        if (this != &other) {
            delete[] buffer_;
            file_path_ = std::move(other.file_path_);
            length_ = other.length_;
            buffer_ = other.buffer_;
            other.length_ = 0;
            other.buffer_ = nullptr;
        }
        return *this;
    }

    // Number of bytes read from the file (0 on failure).
    int GetLength() const {
        return length_;
    }
    // Owned buffer with the file content plus a trailing '\0' (nullptr on failure).
    char* GetBuffer() const {
        return buffer_;
    }

    ~BufferFile() {
        delete[] buffer_;
    }
};

int main(int argc, char* argv[]) {

    // Models path for your model, you have to modify it
    std::string json_file = "./model-symbol.json";
    std::string param_file = "./model-0000.params";

    BufferFile json_data(json_file);
    BufferFile param_data(param_file);

    // Parameters
    int dev_type = 1;  // 1: cpu, 2: gpu
    int dev_id = 1;  // arbitrary.
    mx_uint num_input_nodes = 2;
    mx_uint num_output_nodes = 2;

    const char* input_key[2] = { "data0" , "data1" };
    const char** input_keys = input_key;
    const char* output_key[2] = { "hybridnet0__plus0" , "hybridnet0_relu2" };
    const char** output_keys = output_key;

    // input-dims
    int data0_len = 3;
    int data1_len = 5;
    const mx_uint input_shape_indptr[4] = { 0,2,2,4 };
    const mx_uint input_shape_data[4] = {1,static_cast<mx_uint>(data0_len),1,static_cast<mx_uint>(data1_len) };
    PredictorHandle pred_hnd = 0;

    if (json_data.GetLength() == 0 || param_data.GetLength() == 0)
        return -1;

    // Create Predictor
    assert(0 == MXPredCreatePartialOut(
        (const char*)json_data.GetBuffer(),
        (const char*)param_data.GetBuffer(),
        static_cast<size_t>(param_data.GetLength()),
        dev_type,
        dev_id,
        num_input_nodes,
        input_keys,
        input_shape_indptr,
        input_shape_data,
        num_output_nodes,
        output_keys,
        &pred_hnd));   //ERROR HERE
    assert(pred_hnd);

    return 0;
}

1 个答案:

答案 0 :(得分:1)

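看起来这一行是错误的:const mx_uint input_shape_indptr[4] = { 0,2,2,4 };

input_shape_indptr 是 CSR 形式的索引数组,长度必须是 num_input_nodes + 1:第 i 个输入的形状为 input_shape_data[indptr[i] .. indptr[i+1])。原来的 { 0,2,2,4 } 会让第二个输入 data1 得到一个空形状(区间 [2, 2))。对于两个输入 data0 (1,3) 和 data1 (1,5),

将其更改为 const mx_uint input_shape_indptr[3] = { 0,2,4 };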