我使用 TensorFlow C++ API。我在 GPU 上训练了模型，并执行以下代码（用于预测）：
#include <cstddef>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

using namespace tensorflow;
/// Loads a raw, uncompressed pixel dump into a float tensor of shape
/// [1, height, width, channels].
///
/// The file must start with height*width*channels bytes of interleaved uint8
/// samples in row-major (H, W, C) order. This is NOT an image decoder: an
/// encoded .jpg/.png is read as compressed bytes and yields meaningless input.
/// Each byte is scaled to [0, 1] by dividing by 255.
///
/// @param fname     path of the raw pixel file
/// @param height    image height in pixels (default 224)
/// @param width     image width in pixels (default 224)
/// @param channels  samples per pixel (default 3)
/// @return          a DT_FLOAT tensor of shape [1, height, width, channels]
/// @throws std::runtime_error if the file cannot be opened or holds fewer
///         than height*width*channels bytes.
tensorflow::Tensor loadImage(tensorflow::string fname,
                             tensorflow::int32 height = 224,
                             tensorflow::int32 width = 224,
                             tensorflow::int32 channels = 3) {
  tensorflow::Tensor tensor(tensorflow::DT_FLOAT,
                            tensorflow::TensorShape({1, height, width, channels}));
  // A row-major NHWC tensor viewed flat has exactly the byte order of the file,
  // so a single linear pass replaces the four nested loops.
  auto values = tensor.flat<float>();
  const std::int64_t num_values = values.size();

  std::ifstream fin(fname, std::ios_base::in | std::ios_base::binary);
  // Explicit check instead of assert(): assert vanishes under NDEBUG and the
  // function would then silently return a tensor of garbage.
  if (!fin) {
    throw std::runtime_error("loadImage: cannot open file: " + fname);
  }

  std::vector<char> bytes(static_cast<std::size_t>(num_values));
  fin.read(bytes.data(), static_cast<std::streamsize>(bytes.size()));
  if (fin.gcount() != static_cast<std::streamsize>(bytes.size())) {
    throw std::runtime_error("loadImage: file too short (expected " +
                             std::to_string(bytes.size()) + " bytes, got " +
                             std::to_string(fin.gcount()) + "): " + fname);
  }

  for (std::int64_t k = 0; k < num_values; ++k) {
    // Cast through uint8 so bytes >= 0x80 are not sign-extended to negatives.
    values(k) = static_cast<float>(static_cast<std::uint8_t>(bytes[k]) / 255.0);
  }
  std::cout << "Image Loaded" << std::endl;
  return tensor;
}
int main(int argc, char* argv[]) {
Session* session;
Status status = NewSession(SessionOptions(), &session);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
GraphDef graph_def;
status = ReadBinaryProto(Env::Default(), "graph.pb", &graph_def);
if (!status.ok()) {
std::cout << "Status Not OK" << std::endl;
std::cout << status.ToString() << "\n";
return 1;
}
else{
std::cout << "Graph Loaded" << std::endl;
}
status = session->Create(graph_def);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
else{
std::cout << "Create End" << std::endl;
}
std::string fname = "test.jpg";
tensorflow::Tensor img = loadImage(fname);
std::vector<std::pair<tensorflow::string, tensorflow::Tensor>> inputs = {{"img0001", img }};
std::vector<tensorflow::Tensor> outputs;
std::cout << "Start Run" << std::endl;
status = session->Run(inputs, {"output_node0"}, {}, &outputs);
std::cout << "End Run" << std::endl;
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
std::cout << outputs[0].DebugString() << "\n";
std::cout << output_c() << "\n"; // 30
session->Close();
return 0;
}
但是，我收到了如下错误：
Invalid argument: Tensor img0001:0, specified in either feed_devices or fetch_devices was not found in the Graph
该错误发生在以下这行代码：
session->Run(inputs, {"output_node0"}, {}, &outputs);
根据 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/common_runtime/graph_execution_state.cc 中的代码，当指定的节点名称在图中不存在时就会出现此错误。我的模型是用 Keras（而不是直接用 TensorFlow）创建的，因此我用以下脚本把模型从 Keras 转换成了 TensorFlow 格式：https://github.com/icchi-h/keras_to_tensorflow/blob/master/keras_to_tensorflow.py
我怀疑这与在 GPU 上训练有关（参见 https://github.com/tensorflow/tensorflow/issues/5902），
但我无法证实这一点。
请问该如何解决这个问题？