如何在C ++ TensorFlow中使用CSV中的加载数据?

时间:2017-01-19 02:00:26

标签: c++ machine-learning tensorflow

我正在尝试将使用Python训练的模型加载到C ++中并从CSV中对一些数据进行分类。我找到了这个教程:

https://medium.com/@hamedmp/exporting-trained-tensorflow-models-to-c-the-right-way-cf24b609d183#.3bmbyvby0

这引出了我的这段代码示例:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/label_image/main.cc

对我来说这看起来很有希望。但是,我想要加载的数据是CSV,而不是图像文件,所以我试图重写ReadTensorFromImageFile函数。我能够找到一个类DecodeCSV,但它与示例代码中的DecodePNG和DecodeJpeg类略有不同,我最终得到的是OutputList而不是Output。使用列表中的[]运算符似乎会导致我的程序崩溃。如果有人碰巧知道如何处理这个问题,我将不胜感激。他是对代码的相关更改:

// inside ReadTensorFromText
Output image_reader;
    std::initializer_list<Input>* x = new std::initializer_list<Input>;
    ::tensorflow::ops::InputList defaults = ::tensorflow::ops::InputList(*x);
    OutputList image_read_list;
    image_read_list = DecodeCSV(root.WithOpName("csv_reader"), file_reader, defaults).output;
    // Now cast the image data to float so we can do normal math on it.
    // image_read_list.at(0) crashes the executable.
    auto float_caster =
        Cast(root.WithOpName("float_caster"), image_read_list.at(0), tensorflow::DT_FLOAT);

0 个答案:

没有答案