在学习的张量流模型中无法读取训练的权重

时间:2017-12-18 23:09:10

标签: python tensorflow

我正在尝试从Tensorflow模型中读取通过训练获得的值,但我只获得权重的初始值。我正在使用https://gist.github.com/saitodev/8532cf9e94a9490f75a9bce678751aec中的示例代码,只添加了用于打印W和b值的代码。 我得到偏差值(b),但是权重(W)都是零,这是不可能的,因为学习模型是实际工作的(92%准确度)。我还尝试为权重设置trainable = False,停止模型以了解W确实需要更新才能工作。我应该如何阅读学习重量的值?我的方法有什么问题?

代码:

from __future__ import print_function

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784]) 

W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, W) + b)

y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())

  max_steps = 1000
  for step in range(max_steps):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    if (step % 100) == 0:
      print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

  print(max_steps, sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
  print("W=")
  print(sess.run(W))
  print("b=")
  print(sess.run(b))
  print(max_steps, sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

结果:

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
0 0.3348
100 0.8918
200 0.9037
300 0.906
400 0.9098
500 0.9137
600 0.9168
700 0.9147
800 0.9134
900 0.9193
1000 0.9193
W=
[[ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 ..., 
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]
 [ 0.  0.  0. ...,  0.  0.  0.]]
b=
[-0.38804549  0.35967571  0.09746896 -0.28238639  0.03597458  1.31636047
 -0.11613782  0.64165515 -1.42244864 -0.24211763]
1000 0.9193

1 个答案:

答案 0 :(得分:0)

在您的代码中,sess.run(tf.initialize_all_variables())将以指定的方式初始化变量,例如零。因此,您将获得完整的零W

您展示的代码仅培训模型,不包括保存或加载模型。如果要加载预先训练的模型,则应添加如下代码:

saver = tf.train.Saver()
saver.restore(sess, model_checkpoint_path)

代码不是自己初始化它们,而是帮助您初始化图表中在检查点中找到的所有(可训练)变量。

此外,由参数trainable=True初始化的变量(tensor)将可训练并添加到可训练集合中。

注意,您应该确保要恢复的图形与您定义的图形相同,包括网格结构,操作名称,张量名称等。

您应该查看官方文件:Saving and Restoring