OSX py3.6 tf1.8
我是张量流初学者。我尝试为mnist训练模型。还原模型时出现错误。
from datetime import datetime
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
save_model_path = 'mnist_model/model.ckpt'
def train():
learning_rate = 0.05
batch_size = 100
max_epochs = 100
num_of_batch = int(mnist.train.num_examples / batch_size)
now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
y = tf.placeholder(tf.float32, shape=[None, 10], name='y')
print(X.name, y.name)
W = tf.get_variable(shape=[784, 10], name='weight')
b = tf.get_variable(initializer=tf.zeros([10]), name='bais')
tf.summary.histogram("weights", W)
tf.summary.histogram("biases", b)
with tf.name_scope('pred'):
y_pred = tf.nn.softmax(tf.matmul(X, W) + b, name='predict')
print(y_pred.name)
with tf.name_scope('loss'):
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_pred))
tf.summary.scalar('loss', loss)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
with tf.name_scope('acc'):
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='acc')
print(accuracy.name)
merged_summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
loss_avg = 0
writer = tf.summary.FileWriter('mnist/{}'.format(now), sess.graph)
for epoch in range(max_epochs):
for i in range(num_of_batch):
batch_x, batch_y = mnist.train.next_batch(batch_size)
summary_str, _, l = sess.run([merged_summary_op, optimizer, loss], feed_dict={X: batch_x, y: batch_y})
loss_avg += l
global_step = epoch * num_of_batch + i
writer.add_summary(summary_str, global_step)
if global_step % 100 == 0:
print('Epoch {}: {} save model'.format(epoch, i))
# save model in halfway
saver.save(sess, save_model_path, global_step=global_step)
loss_avg /= num_of_batch
print('Epoch {}: Loss {}'.format(epoch, loss_avg))
print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))
saver.save(sess, save_model_path)
def predict(import_from_meta=False):
if import_from_meta:
meta_path = 'mnist_model/model.ckpt.meta'
checkpoint_path = 'mnist_model'
else:
# stupid var WTF ValueError: No variables to save
_ = tf.Variable(0)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
if import_from_meta:
saver = tf.train.import_meta_graph(meta_path)
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
else:
saver.restore(sess, save_model_path)
graph = tf.get_default_graph()
X = graph.get_tensor_by_name('X:0')
y = graph.get_tensor_by_name('y:0')
accuracy = graph.get_tensor_by_name('acc/acc:0')
print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))
pred = graph.get_tensor_by_name('pred/predict:0')
import matplotlib.pyplot as plt
i = 90
img_orign = mnist.train.images[i]
img = img_orign.reshape((28, 28))
plt.imshow(img, cmap='gray')
plt.title(mnist.train.labels[i])
plt.show()
a = sess.run(pred, feed_dict={X: img_orign.reshape(-1, 784)})
print(a.shape)
import numpy as np
print(np.argmax(a))
def check_ckpt():
from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(save_model_path, tensor_name='', all_tensors=True)
if __name__ == '__main__':
# train()
predict(import_from_meta=False)
# check_ckpt()
使用predict(import_from_meta=False)
错误:
WARNING:tensorflow:From /Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
2018-11-08 16:53:40.482921: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Not found: Key Variable not found in checkpoint
Traceback (most recent call last):
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1322, in _do_call
return fn(*args)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 115, in <module>
predict(import_from_meta=False)
File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 92, in predict
saver.restore(sess, save_model_path)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1802, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1135, in _run
feed_dict_tensor, options, run_metadata)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1316, in _do_run
run_metadata)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1335, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
Caused by op 'save/RestoreV2', defined at:
File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 115, in <module>
predict(import_from_meta=False)
File "/Users/wyx/project/learn-sktf/tf/mnist_clf.py", line 84, in predict
saver = tf.train.Saver()
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1338, in __init__
self.build()
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1347, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1384, in _build
build_save=build_save, build_restore=build_restore)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 835, in _build_internal
restore_sequentially, reshape)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 472, in _AddRestoreOps
restore_sequentially)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 886, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1463, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
op_def=op_def)
File "/Users/wyx/project/learn-sktf/.env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1718, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
NotFoundError (see above for traceback): Key Variable not found in checkpoint
[[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
很奇怪,当我使用predict(import_from_meta=True)
时,我会得到正确答案。
然后,我尝试从check_ckpt()
到inspect variables in a checkpoint。我找不到合适的tensor_name。真好笑,name_scope
就像个玩笑。
tensor_name: bais
[-24.933702 1.7660792 3.697866 -14.221888 8.967291 42.149403
-3.2693458 23.876926 -30.643892 -3.7861202]
tensor_name: bais/Adam
[ 3.1726879e-07 -5.2043208e-07 3.4227469e-05 2.5119303e-07
-2.0110610e-04 1.8493415e-04 -3.6275055e-06 -1.4343520e-04
-7.2765622e-05 2.0172486e-04]
tensor_name: bais/Adam_1
[5.2586905e-08 8.9204484e-08 1.5440051e-07 2.9412612e-07 2.4380788e-07
3.4676964e-07 8.7062219e-08 1.8839150e-07 4.3878950e-07 4.2466107e-07]
tensor_name: loss/beta1_power
0.0
tensor_name: loss/beta2_power
1.2639432e-24
tensor_name: weight
[[-0.03386476 0.03485525 -0.03267809 ... -0.08548199 0.00565728
-0.01887459]
[ 0.00370622 0.08523928 0.05811391 ... -0.07838921 0.05987743
0.074329 ]
[ 0.0180116 0.04400793 -0.0260816 ... 0.00807328 0.06537797
-0.07446742]
...
[-0.00665552 -0.03390152 -0.03889231 ... -0.01871967 -0.05968629
0.07207178]
[ 0.01317277 0.03459686 -0.03268962 ... 0.07082433 0.03290742
0.03172391]
[-0.04514085 -0.03013236 0.01006595 ... 0.01906221 0.02611361
0.04348358]]
tensor_name: weight/Adam
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
tensor_name: weight/Adam_1
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]]
Process finished with exit code 0
那么我的代码有什么问题?为什么只想恢复模型,为什么必须在tf.train.Saver
之前创建变量?