TensorFlow：在图中找不到 feed_devices 或 fetch_devices 中指定的张量

时间:2018-08-04 03:25:29

标签: python c++ tensorflow gpu

我在使用 TensorFlow C++ API。我在 GPU 上训练了模型，并执行以下代码进行预测：

#include <cassert>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
// NOTE(review): the TensorFlow (e.g. tensorflow/core/public/session.h) and
// Boost headers this file relies on are not included here; add them to build.
using namespace tensorflow;

// Loads a raw, interleaved 8-bit RGB image of 224x224 pixels from `fname`
// into a float tensor of shape {nData, height, width, channels}
// (= {1, 224, 224, 3}), scaling every byte to the range [0, 1].
//
// NOTE(review): this reads raw pixel bytes only; it does NOT decode JPEG/PNG.
// Passing an encoded file such as "test.jpg" feeds the compressed bytes
// (headers included) to the model as if they were pixels. Decode the image
// first, or store the input as raw RGB bytes.
//
// Throws std::runtime_error if the file cannot be opened or contains fewer
// bytes than the tensor needs.
tensorflow::Tensor loadImage(tensorflow::string fname){
    const tensorflow::int32 width = 224;
    const tensorflow::int32 height = 224;
    const tensorflow::int32 nData = 1;
    const tensorflow::int32 channels = 3;
    // Use nData for the batch dimension so the shape and the fill loop agree.
    tensorflow::Tensor tensor(tensorflow::DT_FLOAT,
                              tensorflow::TensorShape({nData, height, width, channels}));
    auto mat = tensor.tensor<float, 4>();

    std::ifstream fin(fname, std::ios_base::in | std::ios_base::binary);
    // An explicit check instead of assert(): assert is compiled out under
    // NDEBUG, which would let a missing file silently produce garbage.
    if (!fin) {
        throw std::runtime_error("loadImage: cannot open " + fname);
    }

    // Read every expected byte in one call and verify nothing was missing;
    // the byte-by-byte loop never checked for a short file.
    const std::size_t expected =
        static_cast<std::size_t>(nData) * height * width * channels;
    std::vector<char> bytes(expected);
    fin.read(bytes.data(), static_cast<std::streamsize>(expected));
    if (static_cast<std::size_t>(fin.gcount()) != expected) {
        throw std::runtime_error("loadImage: " + fname + " is shorter than expected");
    }

    // Bytes are consumed in the same order as before: channel fastest, then
    // width, height and batch.
    std::size_t k = 0;
    for (int i = 0; i < nData; i++) {
        for (int h = 0; h < height; h++) {
            for (int w = 0; w < width; w++) {
                for (int j = 0; j < channels; j++) {
                    mat(i, h, w, j) =
                        static_cast<float>(static_cast<uint8_t>(bytes[k++])) / 255.0f;
                }
            }
        }
    }
    std::cout << "Image Loaded" << std::endl;
    return tensor;
}


int main(int argc, char* argv[]) {
  Session* session;
  Status status = NewSession(SessionOptions(), &session);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  GraphDef graph_def;
  status = ReadBinaryProto(Env::Default(), "graph.pb", &graph_def);

  if (!status.ok()) {
    std::cout << "Status Not OK" << std::endl;
    std::cout << status.ToString() << "\n";
    return 1;
  }
  else{
      std::cout << "Graph Loaded" << std::endl;
  }
  status = session->Create(graph_def);
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }
  else{
      std::cout << "Create End" << std::endl;
  }

  std::string fname = "test.jpg";
  tensorflow::Tensor img = loadImage(fname);
  std::vector<std::pair<tensorflow::string, tensorflow::Tensor>> inputs = {{"img0001", img }};

  std::vector<tensorflow::Tensor> outputs;

  std::cout << "Start Run" << std::endl;
  status = session->Run(inputs, {"output_node0"}, {}, &outputs);
  std::cout << "End Run" << std::endl;
  if (!status.ok()) {
    std::cout << status.ToString() << "\n";
    return 1;
  }

  std::cout << outputs[0].DebugString() << "\n";
  std::cout << output_c() << "\n"; // 30

  session->Close();
  return 0;
}

但是，我收到了如下错误：

Invalid argument: Tensor img0001:0, specified in either feed_devices or fetch_devices was not found in the Graph

此错误发生在下面这行代码：

session->Run(inputs, {"output_node0"}, {}, &outputs);

根据 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/graph_execution_state.cc 中的源码，当指定的节点名称与图中的节点名称不一致时就会发生此错误。我的模型是用 Keras（而不是直接用 TensorFlow）创建的，因此我使用以下脚本将模型从 Keras 转换为 TensorFlow 格式：https://github.com/icchi-h/keras_to_tensorflow/blob/master/keras_to_tensorflow.py

我猜这可能与在 GPU 上训练有关，参见：https://github.com/tensorflow/tensorflow/issues/5902

但我无法证实这一点。

请问该如何解决这个问题？

0 个答案:

没有答案