找不到:FeedInputs:无法找到Feed输出TensorFlow

时间:2016-01-14 17:51:25

标签: c++ model tensorflow

我在这个网站上尝试使用Tensorflow保存模型的例子: https://medium.com/jim-fleming/loading-a-tensorflow-graph-with-the-c-api-4caaff88463f#.ji310n4zo

效果很好。但它不保存变量 a b 的值,因为它只保存图形而不保存变量。我试图替换以下行:

tf.train.write_graph(sess.graph_def, 'models/', 'graph.pb', as_text=False)

saver.save(sess, 'models/graph', global_step=0)

当然在创建了保护对象之后。它不起作用,它输出:

未找到:FeedInputs:无法找到Feed输出

我检查了节点加载的节点,它们只是:

  

_Source

     

_SINK

在write_graph函数中,然后在C ++中加载模型,我加载了以下节点:

  

_Source

     

_SINK

     

保存/ restore_slice_1 / shape_and_slice

     

保存/ restore_slice_1 / tensor_name

     

保存/ restore_slice / shape_and_slice

     

保存/ restore_slice / tensor_name

     

保存/保存/ shapes_and_slices

     

保存/保存/ tensor_names

     

保存/ CONST

     

保存/ restore_slice_1

     

保存/ restore_slice

     

B'/ P>      

保存/ Assign_1

     

B /读

     

B / initial_value

     

B /分配

     

     

保存/分配

     

保存/ RESTORE_ALL

     

保存/保存

     

保存/ control_dependency

     

A /读

     

C

     

A / initial_value

     

A /分配

     

初​​始化

     

张量

甚至saver.save()创建的图形文件比write_graph创建的图形文件小得多165B。

2 个答案:

答案 0 :(得分:3)

我不确定这是否是解决问题的最佳方式,但至少可以解决问题。

由于write_graph也可以存储常量的值,因此在使用write_graph函数编写图形之前,我将以下代码添加到python中:

for v in tf.trainable_variables():
    vc = tf.constant(v.eval())
    tf.assign(v, vc, name="assign_variables")

这会创建存储变量的常量'训练后的数值,然后创建张量" assign_variables "将它们分配给变量。现在,当你调用write_graph时,它将存储变量'文件中的值。

唯一剩下的部分是调用这些张量" assign_variables "在c代码中,以确保为变量分配存储在文件中的常量值。这是一种方法:

      Status status = NewSession(SessionOptions(), &session);
      std::vector<tensorflow::Tensor> outputs;
      for(int i = 0;status.ok(); i++) {
        char name[100];
        if (i==0)
            sprintf(name, "assign_variables");
        else
            sprintf(name, "assign_variables_%d", i);

        status = session->Run({}, {name}, {}, &outputs);
      }

答案 1 :(得分:1)

还有另一种方法可以通过调用save/restore_all操作来恢复变量,该操作应该出现在图表中:

std::vector<tensorflow::Tensor> outputs;
Tensor checkpoint_filepath(DT_STRING, TensorShape());
checkpoint_filepath.scalar<std::string>()() = "path to the checkpoint file";
status = session->Run( {{ "save/Const", checkpoint_filepath },}, 
                       {}, {"save/restore_all"}, &outputs);