我用tens命令保存了一个训练有素的模型:
imported_meta = tf.train.import_meta_graph("/tmp/new_trained_model.ckpt.meta")
imported_meta.restore(sess, tf.train.latest_checkpoint(checkpoint_dir="/tmp/,latest_filename="checkpoint"))
然后,我使用以下命令加载模型:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(one_hot_y, 1))
#logits come from the model,there is no error,so didn't post that code
accuracy_operation = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#saver = tf.train.Saver()
def evaluate(X_data, y_data):
num_examples = len(X_data)
total_accuracy = 0
sess = tf.get_default_session()
for offset in range(0, num_examples, BATCH_SIZE):
batch_x, batch_y = X_data[offset:offset+BATCH_SIZE], y_data[offset:offset+BATCH_SIZE]
accuracy = sess.run(accuracy_operation, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.0})
total_accuracy += (accuracy * len(batch_x))
return total_accuracy / num_examples
test_accuracy = evaluate(X_test, y_test)
现在,为了评估准确性,使用了以下功能:
FailedPreconditionError (see above for traceback): Attempting to use uninitialized value Variable_12
[[Node: Variable_12/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_12"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Variable_12)]]
但是上面的函数给出了错误:
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file("/tmp/new_trained_model.ckpt", tensor_name='',all_tensor_names='', all_tensors=True)
但是,当我从图表中打印张量时,它会显示Variable_12矩阵:
tensor_name: Variable_12
[[-0.1013797 -0.08079438 -0.05904691 ... -0.07798752 -0.08208387
-0.18532619]
[ 0.10919656 -0.06162841 -0.19453178 ... -0.03241748 0.1023232
0.07120663]
[-0.10920436 0.00233169 -0.08879709 ... -0.09918057 -0.02546161
0.00903581]
...
[ 0.13858072 0.13791025 -0.12322884 ... -0.15006843 0.00103891
0.06663229]
[-0.14043045 0.14039241 0.15048873 ... 0.07272678 0.00470365
0.0273346 ]
[-0.10976157 -0.10873327 -0.16460624 ... -0.16509598 0.1124685
-0.08858881]]
Variable_12 :(仅显示输出中的一个变量)
public String getPost(){
new GraphRequest(
AccessToken.getCurrentAccessToken(), "me?fields=friends,name", null, HttpMethod.GET,
new GraphRequest.Callback() {
public void onCompleted(GraphResponse response) {
JSONObject object = response.getJSONObject();
try {
friends = object.getJSONObject("friends").getJSONObject("summary").getString("total_count");
} catch (JSONException e) {
e.printStackTrace();
};
}
}
).executeAsync();
return friends;
}
任何人都可以解释为什么显示未初始化的错误,因为值是由inspect_checkpoint确认的吗?
感谢您的时间。
答案 0 :(得分:1)
问题似乎是你构建了两个计算图。
首先,您提到您“执行了定义模型体系结构的部分”。这将为您的模型创建计算图。
然后,你也做了
imported_meta = tf.train.import_meta_graph("/tmp/new_trained_model.ckpt.meta")
这将为您的模型创建第二个计算图。
根据您执行这些操作的准确程度,“计算图”可以位于一个或两个单独的“图形”对象中。在任何情况下,imported_meta.restore
初始化导入的(第二)计算图的变量,但是您调用session.run()
来计算第一个计算图中的张量。没有人在第一个计算图中初始化变量。
如果您已经创建了图表,则修复不会导入(元)图表。只要变量名称和形状没有改变,您就可以使用Saver
来恢复变量的值,而无需再创建任何变量或操作。